use unpack graph primitive instead; add testcases for all grad interface; remove debug log; format code; remove DumpFuncGraph; resolve clang-format; resolve reviews; resolve cpplint; fix review
tags/v0.3.0-alpha
@@ -1199,51 +1199,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec
   return ret_graph;
 }
-FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
-  // slice a tensor
-  // args: tensor, slice or slice tuple
-  const std::string op_name = std::string("UnpackCall");
-  size_t arg_length = args_spec_list.size();
-  if (arg_length < 2) {
-    MS_LOG(EXCEPTION) << "" << op_name << " requires at least two args, but got " << arg_length << ".";
-  }
-  (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
-  FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
-  ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
-  AnfNodePtr fnNode = ret_graph->add_parameter();
-  std::vector<AnfNodePtr> elems;
-  elems.push_back(fnNode);
-  for (size_t index = 1; index < arg_length; index++) {
-    MS_EXCEPTION_IF_NULL(args_spec_list[index]);
-    if (args_spec_list[index]->isa<AbstractTuple>()) {
-      AbstractTuplePtr arg_tuple = dyn_cast<AbstractTuple>(args_spec_list[index]);
-      AnfNodePtr para_tuple = ret_graph->add_parameter();
-      for (size_t i = 0; i < arg_tuple->size(); ++i) {
-        elems.push_back(
-          ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
-      }
-    } else if (args_spec_list[index]->isa<AbstractDictionary>()) {
-      AbstractDictionaryPtr arg_dict = dyn_cast<AbstractDictionary>(args_spec_list[index]);
-      AnfNodePtr para_dict = ret_graph->add_parameter();
-      auto dict_elems = arg_dict->elements();
-      (void)std::transform(
-        dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
-        [ret_graph, para_dict](const AbstractAttribute& item) {
-          return ret_graph->NewCNode(
-            {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first),
-             ret_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)})});
-        });
-    } else {
-      MS_LOG(EXCEPTION) << "" << op_name << " require args should be tuple or dict, but got "
-                        << args_spec_list[index]->ToString();
-    }
-  }
-  ret_graph->set_output(ret_graph->NewCNode(elems));
-  return ret_graph;
-}
 REGISTER_PYBIND_DEFINE(
   TupleAdd_, ([](const py::module* m) {
     (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_").def(py::init<std::string&>());
@@ -1258,10 +1213,5 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) {
                          (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                            .def(py::init<std::string&>());
                        }));
-REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
-                         (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
-                           .def(py::init<std::string&>());
-                       }));
 }  // namespace prim
 }  // namespace mindspore
@@ -29,6 +29,7 @@
 #include "operator/composite/zip_operation.h"
 #include "operator/composite/list_append_operation.h"
 #include "operator/composite/do_signature.h"
+#include "operator/composite/unpack_call.h"
 #include "pipeline/static_analysis/static_analysis.h"
 #include "utils/misc.h"
 #include "utils/any.h"
@@ -154,7 +155,7 @@ class GradOperation : public MetaFuncGraph {
   FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector<AnfNodePtr>& ptrParams,
                        bool applyJ = false);
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
+  bool sens_param() const { return sens_param_; }
   bool get_all_;
   bool get_by_list_;
   bool sens_param_;
@@ -208,17 +209,6 @@ class TensorSlice : public MetaFuncGraph {
 };
 using TensorSlicePtr = std::shared_ptr<TensorSlice>;
-// Expand the tuple and dict parameters generated when parsing the function call,
-// and generate positional parameters and key-value pairs for function.
-class UnpackCall : public MetaFuncGraph {
- public:
-  explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
-  ~UnpackCall() override = default;
-  MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
-  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
-  friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
-};
-using UnpackCallPtr = std::shared_ptr<UnpackCall>;
 }  // namespace prim
 }  // namespace mindspore
@@ -0,0 +1,94 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "operator/composite/unpack_call.h"
+#include <algorithm>
+#include <utility>
+
+#include "./common.h"
+#include "pipeline/static_analysis/abstract_value.h"
+#include "pipeline/static_analysis/dshape.h"
+#include "pipeline/static_analysis/param_validator.h"
+#include "operator/cc_implementations.h"
+#include "ir/anf.h"
+#include "optimizer/opt.h"
+#include "utils/symbolic.h"
+#include "pybind_api/api_register.h"
+
+namespace mindspore {
+// namespace to support composite operators definition
+namespace prim {
+using mindspore::abstract::AbstractAttribute;
+using mindspore::abstract::AbstractBase;
+using mindspore::abstract::AbstractDictionary;
+using mindspore::abstract::AbstractDictionaryPtr;
+using mindspore::abstract::AbstractFunction;
+using mindspore::abstract::AbstractKeywordArg;
+using mindspore::abstract::AbstractTuple;
+using mindspore::abstract::AbstractTuplePtr;
+
+FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
+  // Expand tuple/dict arguments into positional and keyword arguments.
+  // args: fn, tuple or dict argument packs
+  const std::string op_name = std::string("UnpackCall");
+  size_t arg_length = args_spec_list.size();
+  if (arg_length < 2) {
+    MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
+  }
+  (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
+  auto ret_graph = std::make_shared<FuncGraph>();
+  ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
+  AnfNodePtr fnNode = ret_graph->add_parameter();
+  std::vector<AnfNodePtr> elems;
+  elems.push_back(fnNode);
+  for (size_t index = 1; index < arg_length; index++) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[index]);
+    if (args_spec_list[index]->isa<AbstractTuple>()) {
+      auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>();
+      AnfNodePtr para_tuple = ret_graph->add_parameter();
+      for (size_t i = 0; i < arg_tuple->size(); ++i) {
+        elems.push_back(
+          ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
+      }
+    } else if (args_spec_list[index]->isa<AbstractDictionary>()) {
+      AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
+      AnfNodePtr para_dict = ret_graph->add_parameter();
+      auto dict_elems = arg_dict->elements();
+      (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
+                           [ret_graph, para_dict](const AbstractAttribute& item) {
+                             auto dict_get_item = ret_graph->NewCNode(
+                               {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
+                             return ret_graph->NewCNode(
+                               {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item});
+                           });
+    } else {
+      MS_LOG(EXCEPTION) << op_name << " requires args to be tuple or dict, but got "
+                        << args_spec_list[index]->ToString();
+    }
+  }
+  ret_graph->set_output(ret_graph->NewCNode(elems));
+  return ret_graph;
+}
+
+REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
+                         (void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
+                           .def(py::init<std::string&>());
+                       }));
+}  // namespace prim
+}  // namespace mindspore
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
+#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
+
+#include <vector>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <map>
+#include <set>
+#include <memory>
+
+#include "pipeline/static_analysis/static_analysis.h"
+#include "utils/misc.h"
+#include "utils/any.h"
+#include "ir/dtype.h"
+#include "ir/meta_func_graph.h"
+#include "common/utils.h"
+
+namespace mindspore {
+// namespace to support composite operators definition
+namespace prim {
+// Expand the tuple and dict parameters generated when parsing the function call,
+// and generate positional parameters and key-value pairs for the function.
+class UnpackCall : public MetaFuncGraph {
+ public:
+  explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
+  ~UnpackCall() override = default;
+  MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
+  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
+  friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
+};
+using UnpackCallPtr = std::shared_ptr<UnpackCall>;
+}  // namespace prim
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
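For orientation, this is the user-level pattern that exercises UnpackCall: a minimal sketch based on the var-arg calls in the tests added by this change (`Wrapper` and `Inner` are hypothetical names, not part of the patch):

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class Inner(Cell):
    """Hypothetical inner net taking ordinary positional args."""
    def construct(self, x, y):
        return x + y

class Wrapper(Cell):
    """Forwarding *args makes the parser emit an UnpackCall node."""
    def __init__(self, net):
        super(Wrapper, self).__init__()
        self.net = net

    def construct(self, *args):
        # UnpackCall's generated graph expands the packed tuple back into
        # positional parameters (one tuple_getitem per element); dict packs
        # would become keyword arguments via make_keyword_arg.
        return self.net(*args)

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
out = Wrapper(Inner())(x, y)
```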
@@ -246,6 +246,21 @@ class DoSignaturePrimitive : public Primitive {
   ValuePtr function_;
 };
 using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
+
+class UnpackGraphPrimitive : public Primitive {
+ public:
+  explicit UnpackGraphPrimitive(const std::string& name, bool with_sens, bool need_unpack_args)
+      : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {}
+  ~UnpackGraphPrimitive() override = default;
+  MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive)
+  bool with_sens_in_args() const { return with_sens_in_args_; }
+  bool need_unpack_args() const { return need_unpack_args_; }
+
+ private:
+  bool with_sens_in_args_;
+  bool need_unpack_args_;
+};
+using UnpackGraphPrimitivePtr = std::shared_ptr<UnpackGraphPrimitive>;
 }  // namespace prim
 }  // namespace mindspore
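To make the two flags concrete: `need_unpack_args_` is set when the grad call goes through UnpackCall (i.e. the arguments arrive as a packed tuple/dict), and `with_sens_in_args_` when the GradOperation carries a sens parameter, so the evaluator can drop the trailing sensitivity argument before specializing the forward graph. A minimal sketch mirroring the added tests (`AddNet` is a stand-in defined here, not part of the patch):

```python
import numpy as np
import mindspore.ops.composite as C
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class AddNet(Cell):
    def construct(self, x, y):
        return x + y

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
sens = Tensor(1.0, dtype=mstype.float32)

# sens is an extra runtime argument, not a forward-graph input: the
# UnpackGraph node is built with with_sens_in_args_ = true so that sens
# is excluded when specializing AddNet's graph.
out = C.grad_all_with_sens(AddNet())(x, y, sens)
```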
@@ -39,6 +39,7 @@
 #include "optimizer/irpass/specialize_transform.h"
 #include "optimizer/irpass/incorporate_getitem.h"
 #include "optimizer/irpass/incorporate_call.h"
+#include "optimizer/irpass/grad_var_prepare.h"

 namespace mindspore {
 namespace opt {
@@ -123,6 +124,11 @@ ResolveIRPassLib::ResolveIRPassLib() {
   resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
   resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
 }
+
+InferenceOptPrepareLib::InferenceOptPrepareLib() {
+  grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
+}
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
@@ -102,6 +102,13 @@ class ResolveIRPassLib {
   SubstitutionPtr resolver_getattr_;
 };

+class InferenceOptPrepareLib {
+ public:
+  InferenceOptPrepareLib();
+  ~InferenceOptPrepareLib() = default;
+  SubstitutionPtr grad_var_prepare_;
+};
+
 // predicate functions
 inline bool IsNode(const AnfNodePtr &) { return true; }
@@ -151,6 +158,7 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
   }
   return false;
 }
+
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "optimizer/irpass/grad_var_prepare.h"
+#include <vector>
+#include <algorithm>
+#include <unordered_map>
+#include <memory>
+
+#include "operator/composite/composite.h"
+#include "operator/ops.h"
+#include "optimizer/irpass.h"
+#include "optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "ir/func_graph.h"
+#include "ir/func_graph_cloner.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
+                                          AnfNodePtr func_node, bool is_unpack, bool sens_param) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(func_node);
+  std::vector<AnfNodePtr> nodes;
+  AnfNodePtr unpack_graph_node = nullptr;
+  if (is_unpack) {
+    auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, true);
+    nodes.push_back(NewValueNode(unpack_graph));
+    nodes.push_back(func_node);
+    // {unpackcall, {GradOperation, ...}, args...}
+    (void)nodes.insert(nodes.end(), inputs_y.begin() + 2, inputs_y.end());
+    unpack_graph_node = func_graph->NewCNode(nodes);
+  } else {
+    auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
+    nodes.push_back(NewValueNode(unpack_graph));
+    nodes.push_back(func_node);
+    // {{GradOperation, ...}, args...}
+    (void)nodes.insert(nodes.end(), inputs_y.begin() + 1, inputs_y.end());
+    unpack_graph_node = func_graph->NewCNode(nodes);
+  }
+  return unpack_graph_node;
+}
+
+// get metagraph of value node
+MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
+  ValuePtr value;
+  if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
+    value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
+  } else {
+    value = GetValueNode(node);
+  }
+  if (value == nullptr) {
+    return nullptr;
+  }
+  return value->cast<MetaFuncGraphPtr>();
+}
+
+// check if node is a specific metafuncgraph op
+bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) {
+  if (node != nullptr) {
+    auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
+    if (meta_func_graph_ptr == nullptr) {
+      return false;
+    }
+    if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// {{GradOperation, g, w}, Ys}
+// {UnPackCall, {GradOperation, g, w}, Ys}
+AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) {
+  if (!node->isa<CNode>() || node->func_graph() == nullptr) {
+    return nullptr;
+  }
+
+  // {{...}, Ys}
+  auto inputs_y = node->cast<CNodePtr>()->inputs();
+  std::vector<AnfNodePtr> inputs_x;
+  if (IsCNode(inputs_y[0])) {
+    inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
+  } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) {
+    inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs();
+  } else {
+    return nullptr;
+  }
+
+  // {{...}, Xs}
+  if (inputs_x.size() < 2) {
+    return nullptr;
+  }
+
+  // {GradOperation, g, w} or {GradOperation, g}
+  if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) {
+    return nullptr;
+  }
+  auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
+  if (meta_func == nullptr) {
+    return nullptr;
+  }
+  auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
+  MS_EXCEPTION_IF_NULL(grad_op_ptr);
+  auto func_node = inputs_x[1];
+  if (!IsValueNode<FuncGraph>(func_node)) {
+    return nullptr;
+  }
+
+  AnfNodePtr unpack_graph_node =
+    GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
+                            IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
+  // construct a new GradOperation call whose forward graph is wrapped by UnpackGraph
+  inputs_x[1] = unpack_graph_node;
+  auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
+  if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
+    inputs_y[1] = grad_op_cnode;
+  } else {
+    inputs_y[0] = grad_op_cnode;
+  }
+  auto cnode = node->func_graph()->NewCNode(inputs_y);
+  return cnode;
+}
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
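The two IR shapes this pass matches correspond to the ordinary and var-arg grad calls at the Python level. A minimal sketch mirroring the tests added below (`AddNet` and `GradNet` are stand-in names defined here):

```python
import numpy as np
import mindspore.ops.composite as C
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class AddNet(Cell):
    def construct(self, x, y):
        return x + y

x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)

# {{GradOperation, g}, Ys}: a direct grad call with fixed positional args.
out1 = C.grad_all(AddNet())(x, y)

class GradNet(Cell):
    """Forwarding *inputs routes the call through UnpackCall, producing
    the second pattern: {UnpackCall, {GradOperation, g}, Ys}."""
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        return C.grad_all(self.net)(*inputs)

out2 = GradNet(AddNet())(x, y)
```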
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
+
+#include <vector>
+#include <algorithm>
+#include <unordered_map>
+#include <memory>
+
+#include "operator/composite/composite.h"
+#include "operator/ops.h"
+#include "optimizer/irpass.h"
+#include "optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "ir/func_graph.h"
+#include "ir/func_graph_cloner.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {{GradOperation, g, w}, Ys}
+// {UnPackCall, {GradOperation, g, w}, Ys}
+class GradVarPrepare : public AnfVisitor {
+ public:
+  GradVarPrepare()
+      : grad_op_(std::make_shared<prim::GradOperation>("grad")),
+        unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
+  ~GradVarPrepare() override = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
+
+ private:
+  MetaFuncGraphPtr grad_op_;
+  MetaFuncGraphPtr unpack_op_;
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
@@ -175,10 +175,10 @@ bool CombineLikeGraphs(const ResourcePtr&) {
 bool SymbolResolveAction(const ResourcePtr& res) {
   if (res->manager() == nullptr) {
-    MS_LOG(EXCEPTION) << "Resolve error.";
+    MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
   }
   if (res->func_graph() == nullptr) {
-    MS_LOG(EXCEPTION) << "Resolve error";
+    MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
   }
   FuncGraphPtr func_graph = res->func_graph();
   auto succ = parse::ResolveFuncGraph(func_graph, res);
@@ -194,6 +194,16 @@ bool SymbolResolveAction(const ResourcePtr& res) {
   return succ;
 }
+
+bool InferenceOptPrepareAction(const ResourcePtr& res) {
+  if (res->manager() == nullptr) {
+    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
+  }
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
+  }
+  return InferenceOptPreparePass(res);
+}
+
 bool AbstractSpecializeAction(const ResourcePtr& res) {
   if (res->func_graph() == nullptr) {
     MS_LOG(EXCEPTION) << "AbstractSpecialize error";
@@ -331,7 +341,7 @@ static std::vector<ActionItem> CommonPipeline() {
   // Resolve the python func
   actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
   actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
+  actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
   // Evaluate type and shape, and specialize
   actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
@@ -160,6 +160,13 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) {
   return map;
 }
+
+OptPassGroupMap GetInferenceOptPreparePhases() {
+  opt::irpass::InferenceOptPrepareLib irpass;
+  auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
+  opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}});
+  return prepare_map;
+}
+
 OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) {
   opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
   OptPassGroupMap map({{"prepare_group", prepare_group}});
@@ -239,6 +246,16 @@ bool ValidatePass(const ResourcePtr& res) {
   return true;
 }
+
+bool InferenceOptPreparePass(const ResourcePtr& res) {
+  FuncGraphPtr func_graph = res->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  abstract::AbstractBasePtrList args_spec = res->args_spec();
+  auto prepare_map = GetInferenceOptPreparePhases();
+  auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
+  (void)infer_opt_prepare->step(func_graph, args_spec, false);
+  return true;
+}
+
 std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_a", OptPassAGroup},
                                    {"opt_b", OptPassBGroup},
@@ -34,7 +34,7 @@ bool CconvPass(const ResourcePtr& res);
 bool ValidatePass(const ResourcePtr& res);
 bool ConvertPrepareAdapt(const ResourcePtr& res);
 bool AddControlDependPass(const ResourcePtr& res);
+bool InferenceOptPreparePass(const ResourcePtr& res);
 void ReclaimOptimizer();
 }  // namespace pipeline
 }  // namespace mindspore
@@ -133,6 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
   FuncGraphPtr func_graph_;
   AnalysisContextPtr context_;
 };
+using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;

 class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
  public:
@@ -41,7 +41,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func
   } else {
     oss << "nullptr";
   }
-  MS_LOG(EXCEPTION) << "" << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
+  MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
 }
   return NewContext(parent_context, func_graph, args_spec_list);
 }
@@ -180,6 +180,85 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
   return engine->ForwardConfig(out_conf, fn_conf);
 }
+
+static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
+  // arg[0] is the func graph to unpack, ignore it
+  AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
+  AbstractBasePtrList graph_specialize_args;
+  if (need_unpack) {
+    for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
+      MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
+      if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
+        AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
+        (void)std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
+                             std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
+      } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
+        AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
+        auto dict_elems = arg_dict->elements();
+        (void)std::transform(
+          dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
+          [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
+      } else {
+        MS_LOG(EXCEPTION) << "UnpackGraph requires args to be tuple or dict, but got "
+                          << specialize_args_before_unpack[index]->ToString();
+      }
+    }
+  } else {
+    graph_specialize_args = specialize_args_before_unpack;
+  }
+  return graph_specialize_args;
+}
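Conceptually, the flattening performed by GetUnpackGraphSpecArgsList over abstract values looks like this plain-Python mirror (illustrative only; `flatten_spec_args` is a hypothetical name, not part of the patch):

```python
def flatten_spec_args(args_spec_list, need_unpack):
    """Plain-Python mirror of GetUnpackGraphSpecArgsList (illustrative only).

    args_spec_list[0] is the function being unpacked and is skipped; tuples
    contribute positional args, dicts contribute (key, value) keyword args.
    """
    packed = args_spec_list[1:]
    if not need_unpack:
        return list(packed)
    flat = []
    for arg in packed:
        if isinstance(arg, tuple):
            flat.extend(arg)
        elif isinstance(arg, dict):
            flat.extend((k, v) for k, v in arg.items())  # keyword-arg stand-in
        else:
            raise TypeError("UnpackGraph requires tuple or dict args, got %r" % (arg,))
    return flat

# e.g. flatten_spec_args([fn, (x, y), {"z": 1}], True) -> [x, y, ("z", 1)]
```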
+
+AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
+                                          AnfNodeConfigPtr out_conf) {
+  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
+  }
+  if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
+    MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
+  }
+  auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
+  auto out_node = out_conf->node()->cast<CNodePtr>();
+  const auto &out_node_inputs = out_node->inputs();
+  if (out_node_inputs.empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
+    MS_LOG(EXCEPTION) << "UnpackGraphPrimitive args size should equal inputs size minus 1, but args size is "
+                      << args_conf_list.size() << ", inputs size is " << out_node_inputs.size();
+  }
+  AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
+  AbstractBasePtrList args_spec_list;
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
+  // get the forward graph
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
+  if (fn == nullptr) {
+    MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
+  }
+  auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
+  MS_EXCEPTION_IF_NULL(real_fn);
+  FuncGraphPtr forward_graph = real_fn->func_graph();
+  MS_EXCEPTION_IF_NULL(forward_graph);
+  AbstractBasePtrList graph_specialize_args =
+    GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
+  AbstractBasePtrList graph_specialize_args_without_sens;
+  (void)std::transform(graph_specialize_args.begin(),
+                       graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
+                       std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
+  auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
+  engine->func_graph_manager()->AddFuncGraph(new_graph);
+  ScopePtr scope = kDefaultScope;
+  if (out_conf != nullptr) {
+    scope = out_conf->node()->scope();
+  }
+  ScopeGuard scope_guard(scope);
+  AnfNodePtr new_vnode = NewValueNode(new_graph);
+  AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context());
+  return engine->ForwardConfig(out_conf, fn_conf);
+}
 namespace {
 py::object BuildValue(const ValuePtr &value_ptr) {
   if (value_ptr == nullptr) {
@@ -87,6 +87,21 @@ class DoSignatureEvaluator : public Evaluator {
   PrimitivePtr prim_;
 };

+class UnpackGraphEvaluator : public Evaluator {
+ public:
+  explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
+  ~UnpackGraphEvaluator() override = default;
+  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
+                      AnfNodeConfigPtr out_config = nullptr) override;
+  AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override {
+    MS_LOG(EXCEPTION) << "Infer() should not be called; call Run() instead";
+  }
+
+ private:
+  PrimitivePtr prim_;
+};
+
 bool IsInWhiteList(PrimitivePtr primitive);
 StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
@@ -289,6 +289,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     evaluator = std::make_shared<DoSignatureEvaluator>(prim);
     return evaluator;
   }
+  if (prim->isa<prim::UnpackGraphPrimitive>()) {
+    evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
+    return evaluator;
+  }
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
@@ -19,6 +19,8 @@ from mindspore.nn import Cell
 from mindspore.ops import operations as P
 import mindspore.ops.composite as C
 from mindspore.common.api import _executor
+from mindspore.common.parameter import ParameterTuple
+from mindspore.common import dtype as mstype

 context.set_context(mode=context.GRAPH_MODE)
@@ -34,3 +36,152 @@ def test_net_vargs_expand():
     sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
     net = AddNet()
     out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens)
+
+
+class VarNet(Cell):
+    def __init__(self, net):
+        super(VarNet, self).__init__()
+        self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
+        self.w = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "w", requires_grad=True)
+        self.net = net
+
+    def construct(self, *args):
+        return self.net(*args) * self.w + self.b
+
+
+class SecondNet(Cell):
+    def __init__(self):
+        super(SecondNet, self).__init__()
+        self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
+
+    def construct(self, *args):
+        res = args[0] + args[1]
+        return res + self.b2
+
+
+def test_all_var_args_grad_with_sens():
+    """test grad_by_list_with_sens with all var args input"""
+    class GradNet(Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.weights = ParameterTuple(net.trainable_params())
+            self.net = net
+
+        def construct(self, *inputs):
+            return C.grad_by_list_with_sens(self.net, self.weights)(*inputs)
+
+    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    sens = Tensor(1.0, dtype=mstype.float32)
+    net = VarNet(SecondNet())
+    grad_net = GradNet(net)
+    out = grad_net(x, y, sens)
+
+
+def test_grad_list_var_args():
+    class GradNet(Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.weights = ParameterTuple(net.trainable_params())
+            self.net = net
+
+        def construct(self, *inputs):
+            return C.grad_by_list(self.net, self.weights)(*inputs)
+
+    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    net = VarNet(SecondNet())
+    grad_net = GradNet(net)
+    out = grad_net(x, y)
+
+
+def test_grad_all_var_args():
+    class GradNet(Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.weights = ParameterTuple(net.trainable_params())
+            self.net = net
+
+        def construct(self, *inputs):
+            return C.grad_all(self.net)(*inputs)
+
+    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    net = VarNet(SecondNet())
+    grad_net = GradNet(net)
+    out = grad_net(x, y)
+
+
+def test_grad_all_var_args_with_sens():
+    class GradNet(Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.weights = ParameterTuple(net.trainable_params())
+            self.net = net
+
+        def construct(self, *inputs):
+            return C.grad_all_with_sens(self.net)(*inputs)
+
+    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    sens = Tensor(1.0, dtype=mstype.float32)
+    net = VarNet(SecondNet())
+    grad_net = GradNet(net)
+    out = grad_net(x, y, sens)
+
+
+def test_grad_var_args_with_sens():
+    class GradNet(Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.weights = ParameterTuple(net.trainable_params())
+            self.net = net
+
+        def construct(self, *inputs):
+            return C.grad_with_sens(self.net)(*inputs)
+
+    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    sens = Tensor(1.0, dtype=mstype.float32)
+    net = VarNet(SecondNet())
+    grad_net = GradNet(net)
+    out = grad_net(x, y, sens)
+
+
+def test_var_args_grad():
+    class VarNet(Cell):
+        def __init__(self, net):
+            super(VarNet, self).__init__()
+            self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
+            self.net = net
+
+        def construct(self, *args):
+            return self.net(*args) + self.b
+
+    class SecondNet(Cell):
+        def __init__(self):
+            super(SecondNet, self).__init__()
+            self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
+
+        def construct(self, *args):
+            res = args[0] + args[1]
+            return res + self.b2
+
+    class GradNet(Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, x, y, sens):
+            return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
+
+    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    sens = Tensor(1.0, dtype=mstype.float32)
+    net = VarNet(SecondNet())
+    grad_net = GradNet(net)
+    out = grad_net(x, y, sens)
+
+
+def test_var_args_positional():
+    """test grad_all with var args in inner graph"""
+    class VarNet(Cell):
+        def __init__(self, net):
+            super(VarNet, self).__init__()
+            self.net = net
+
+        def construct(self, x, y):
+            return self.net(x, y) * x
+
+    class SecondNet(Cell):
+        def __init__(self):
+            super(SecondNet, self).__init__()
+
+        def construct(self, *args):
+            return args[0] + args[1]
+
+    class GradNet(Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, x, y):
+            return C.grad_all(self.net)(x, y)
+
+    x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
+    net = VarNet(SecondNet())
+    grad_net = GradNet(net)
+    out = grad_net(x, y)