@@ -944,6 +944,80 @@ REGISTER_PYBIND_DEFINE(VmapOperation_, ([](const py::module *m) {
                           .def(py::init<std::string &>(), py::arg("fn"));
                       }));

TaylorOperation::TaylorOperation(const std::string &name) : MetaFuncGraph(name) {
  // def Taylor(func:read):
  signatures_ = std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
}

FuncGraphPtr TaylorOperation::GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params) {
  FuncGraphPtr k_child = std::make_shared<FuncGraph>();
  k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(k);
  MS_LOG(INFO) << "TaylorOperation forward input size " << forward_graph_params.size();
  for (size_t i = 0; i < forward_graph_params.size(); ++i) {
    inputs.push_back(k_child->add_parameter());
  }
  // Taylor(fn)(input params)
  auto k_app = k_child->NewCNodeInOrder(inputs);
  k_child->set_output(k_app);
  return k_child;
}
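For orientation: the child graph built here only forwards its fresh parameters to the Taylor-marked function. A rough Python sketch of what `GetTaylorGrad` generates (the names `k_child` and `taylor_fn` are illustrative; `taylor_fn` stands for the `Taylor(fn)` CNode passed in as `k`):

    def k_child(*params):          # one fresh parameter per forward-graph parameter
        return taylor_fn(*params)  # Taylor(fn)(input params)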
// Generate the graph to calculate higher order derivatives.
FuncGraphPtr TaylorOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION)
      << "'TaylorOperation' requires a forward network or function as an input, while the input is empty.";
  }

  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "'TaylorOperation' arg0 must be a 'Function' or 'Cell', but got "
                      << args_spec_list[0]->ToString();
  }

  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);

  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);
  forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  MS_LOG(INFO) << "'TaylorOperation' forward_graph: " << forward_graph->debug_info();

  FuncGraphPtr grad_fg = std::make_shared<FuncGraph>();
  auto nparam = forward_graph->parameters().size();

  std::ostringstream ss;
  ss << "taylorgrad{" << nparam << "}";
  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  grad_fg->debug_info()->set_name(ss.str());

  ParameterPtr param_graph = grad_fg->add_parameter();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimTaylor));
  inputs.push_back(param_graph);
  // Taylor(fn)
  auto mark_taylor = grad_fg->NewCNodeInOrder(inputs);
  FuncGraphPtr k_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    k_child = GetTaylorGrad(mark_taylor, forward_graph->parameters());
  }
  grad_fg->set_output(NewValueNode(k_child));
  // return Taylor(fn)(inputs)
  return grad_fg;
}
REGISTER_PYBIND_DEFINE(TaylorOperation_, ([](const py::module *m) {
                         (void)py::class_<TaylorOperation, MetaFuncGraph, std::shared_ptr<TaylorOperation>>(
                           *m, "TaylorOperation_")
                           .def(py::init<std::string &>(), py::arg("fn"));
                       }));

// Generate the ListMap func graph.
FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  size_t args_num = args_spec_list.size();

@@ -166,6 +166,17 @@ class GradOperation : public MetaFuncGraph {
};
using GradOperationPtr = std::shared_ptr<GradOperation>;

class TaylorOperation : public MetaFuncGraph {
 public:
  explicit TaylorOperation(const std::string &name);
  ~TaylorOperation() override = default;
  MS_DECLARE_PARENT(TaylorOperation, MetaFuncGraph);
  FuncGraphPtr GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params);
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
};
using TaylorOperationPtr = std::shared_ptr<TaylorOperation>;

class ListMap {
 public:
  explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
@@ -738,6 +738,25 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
  return AbstractFunction::MakeAbstractFunction(jv);
}

AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
  // args: An object of AbstractFunction.
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  MS_LOG(DEBUG) << "evaluate Taylor: " << args_spec_list[0]->ToString();
  AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
  MS_EXCEPTION_IF_NULL(x);

  AbstractFuncAtomPtrList taylor_v;
  auto build_taylor_v = [&taylor_v](const AbstractFuncAtomPtr &func) {
    auto taylor_closure = std::make_shared<TaylorTransformedAbstractClosure>(func);
    taylor_v.push_back(taylor_closure);
  };
  x->Visit(build_taylor_v);
  return AbstractFunction::MakeAbstractFunction(taylor_v);
}

AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  // Inputs: func, in_axes, out_axes, device, level.

@@ -846,6 +865,7 @@ REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferI
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Taylor, prim::kPrimTaylor, InferImplTaylor, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Shard, prim::kPrimShard, InferImplShard, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(Vmap, prim::kPrimVmap, InferImplVmap, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,

@@ -55,6 +55,8 @@ AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &
                                 const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -21,6 +21,7 @@
#include "frontend/optimizer/irpass/convert.h"
#include "frontend/optimizer/irpass/environ_eliminate.h"
#include "frontend/optimizer/irpass/grad_var_prepare.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#include "frontend/optimizer/irpass/load_eliminate.h"

@@ -22,6 +22,7 @@
#include "base/core_ops.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/vmap_eliminate.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"

namespace mindspore {
@@ -35,8 +36,8 @@ class ExpandMetaFg {
    // to the implementation of `kPrimVmap`.
    (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandJPrim>());
    (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandVmapPrim>());
    (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandTaylorPrim>());
  }
  virtual ~ExpandMetaFg() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
@@ -0,0 +1,158 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "pipeline/pynative/pynative_execute.h"

namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// White list of ops with taylor rule.
mindspore::HashSet<std::string> taylor_ops{prim::kPrimAdd->name(), prim::kPrimSub->name(), prim::kPrimRealDiv->name(),
                                           prim::kPrimMul->name(), prim::kPrimSin->name(), prim::kPrimCos->name(),
                                           prim::kPrimExp->name()};
// The ops below are excluded when considering taylor rules.
mindspore::HashSet<std::string> taylor_exception_ops{prim::kPrimReturn->name(), prim::kPrimMakeTuple->name(),
                                                     prim::kPrimTupleGetItem->name(), prim::kPrimCast->name()};
// Cache of primitive ops which have been replaced by taylor rule.
mindspore::HashMap<PrimitivePtr, FuncGraphPtr> taylor_ops_cache_;
FuncGraphPtr GetTaylorRule(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
  // Set a child scope named "grad'PrimitiveName'" for the taylor rule function,
  // and add "Gradients" to the front.
  static const std::string gradients_scope = "Gradients/";
  static const std::string grad_op_child_scope_prefix = "/grad";
  MS_EXCEPTION_IF_NULL(prim);
  auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
                                       grad_op_child_scope_prefix + prim->name());
  ScopeGuard scope_guard(scope);

  // Look up the taylor rule function registered for this primitive, then parse it into a func graph.
  FuncGraphPtr func_graph = nullptr;
  py::function taylor_fn;
  if (prim->is_base()) {
    taylor_fn = GetTaylorRuleFunction(prim->name());
  } else {
    taylor_fn = prim->cast<PrimitivePyPtr>()->GetTaylorRuleFunction();
    if (py::isinstance<py::none>(taylor_fn)) {
      taylor_fn = GetTaylorRuleFunction(prim->name());
    }
  }
  if (!taylor_fn || py::isinstance<py::none>(taylor_fn)) {
    MS_LOG(INFO) << "Failed to find the taylor rule function for " << prim->name()
                 << ". taylor_fn: " << py::str(taylor_fn);
    return nullptr;
  }
  func_graph = parse::ParsePythonCode(taylor_fn);
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Failed to parse the taylor rule function for " << prim->name() << ".";
    return nullptr;
  }
  auto taylor_rule_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
  if (taylor_rule_flag) {
    func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  }
  pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
  (void)parse::ResolveFuncGraph(func_graph, res);
  return func_graph;
}
FuncGraphPtr GetTaylorPyObj(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
  auto fg = GetTaylorRule(prim, resources);
  return fg;
}

FuncGraphPtr GetTaylorPrimitive(const AnfNodePtr &node, const pipeline::ResourceBasePtr &resources) {
  auto prim_node = GetValueNode<PrimitivePtr>(node);
  MS_EXCEPTION_IF_NULL(prim_node);
  auto iter = taylor_ops_cache_.find(prim_node);
  if (iter != taylor_ops_cache_.end()) {
    return iter->second;
  }
  FuncGraphPtr primitive_taylor = GetTaylorPyObj(prim_node, resources);
  MS_EXCEPTION_IF_NULL(primitive_taylor);
  taylor_ops_cache_[prim_node] = primitive_taylor;
  return primitive_taylor;
}

FuncGraphPtr TaylorFunctor(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources) {
  const auto &value_nodes = func_graph->value_nodes();
  auto manager = resources->manager();
  manager->AddFuncGraph(func_graph);
  std::vector<AnfNodePtr> taylor_node_list;
  for (const auto &value_pair : value_nodes) {
    auto node = value_pair.first;
    MS_EXCEPTION_IF_NULL(node);
    if (IsValueNode<Primitive>(node)) {
      auto prim_node = GetValueNode<PrimitivePtr>(node);
      if (taylor_ops.count(prim_node->name())) {
        taylor_node_list.push_back(node);
      } else if (!taylor_exception_ops.count(prim_node->name())) {
        MS_LOG(EXCEPTION) << "The operation " << prim_node->name()
                          << " is currently not supported in taylor higher order differentiation.";
      }
    }
  }
  for (size_t i = 0; i < taylor_node_list.size(); i++) {
    FuncGraphPtr taylor_node_graph = GetTaylorPrimitive(taylor_node_list[i], resources);
    manager->Replace(taylor_node_list[i], NewValueNode(taylor_node_graph));
  }
  taylor_ops_cache_.clear();
  MS_LOG(INFO) << "Replaced taylor nodes in " << func_graph->ToString() << ".";
  return func_graph;
}
AnfNodePtr ExpandTaylor(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) {
  if (IsValueNode<FuncGraph>(vnode)) {
    ScopeGuard scope_guard(vnode->scope());
    auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
    MS_EXCEPTION_IF_NULL(func_graph);
    MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expand Taylor now";
    auto newfg = TaylorFunctor(func_graph, resource);
    return NewValueNode(newfg);
  }
  return nullptr;
}
}  // namespace internal

bool ExpandTaylorPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  // Search all taylor nodes.
  bool change = false;
  auto manager = optimizer->manager();
  for (auto &taylor_node : prim_nodes_) {
    auto taylor_fg_node = taylor_node->input(1);
    auto taylor_fg = GetValueNode<FuncGraphPtr>(taylor_fg_node);
    if (taylor_fg == nullptr) {
      // Report the offending node; taylor_fg itself is null in this branch and must not be dereferenced.
      MS_LOG(EXCEPTION) << "Unexpected Taylor node, input func graph should not be null, node: "
                        << taylor_fg_node->DebugString();
    }
    // Copy the original forward graph in case it is used elsewhere.
    auto taylor_fg_copy = BasicClone(taylor_fg, true);
    manager->AddFuncGraph(taylor_fg_copy);
    auto taylor_fg_copy_node = NewValueNode(taylor_fg_copy);
    // Return the expanded taylor graph.
    auto expanded_taylor = internal::ExpandTaylor(taylor_fg_copy_node->cast<ValueNodePtr>(), optimizer->resource());
    manager->Replace(taylor_node, expanded_taylor);
    change = true;
  }
  return change;
}
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
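Taken together, the pass above expands each `Taylor` application roughly as follows (an editorial sketch in pseudocode comments, not code from the patch):

    # For each collected CNode (kPrimTaylor, fn):
    #   taylor_fg      = func graph held by fn              (error if null)
    #   taylor_fg_copy = BasicClone(taylor_fg)              # the original stays usable elsewhere
    #   for each Primitive value node prim in taylor_fg_copy:
    #     if prim in taylor_ops:                   replace prim with its parsed taylor rule graph
    #     elif prim not in taylor_exception_ops:   raise "not supported"
    #   replace the (kPrimTaylor, fn) node with taylor_fg_copy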
@@ -0,0 +1,46 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_

#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/anf_visitor.h"
#include "utils/ms_utils.h"
#include "include/common/utils/primitive_utils.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimTaylor, C}
class ExpandTaylorPrim : public ExpandMetaFGPrim {
 public:
  ExpandTaylorPrim() { prim_ = prim::kPrimTaylor; }
  virtual ~ExpandTaylorPrim() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_TAYLOR_ELIMINATE_H_
@@ -353,6 +353,7 @@ constexpr char SIMPLE_MEAN[] = "SimpleMean";
constexpr char FLATTEN[] = "Flatten";
constexpr char J[] = "J";
constexpr char SHARD[] = "Shard";
constexpr char Taylor[] = "Taylor";
constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info";
constexpr char COS[] = "Cos";
constexpr char ACOS[] = "ACos";

@@ -30,6 +30,10 @@ COMMON_EXPORT py::function GetBpropFunctionByObj(const py::object &obj);
COMMON_EXPORT py::function GetBpropFunction(const std::string &name);

COMMON_EXPORT py::function GetTaylorRuleFunctionByObj(const py::object &obj);

COMMON_EXPORT py::function GetTaylorRuleFunction(const std::string &name);

COMMON_EXPORT py::function GetComputeFunction(const std::string &name);
COMMON_EXPORT BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);

@@ -52,6 +52,7 @@
#include "frontend/optimizer/irpass/ge_specialized_prepare.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
#include "frontend/optimizer/irpass/shard_eliminate.h"
#include "frontend/optimizer/irpass/taylor_eliminate.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
@@ -572,6 +572,27 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
  return res;
}

EvalResultPtr TaylorEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                   const AnfNodeConfigPtr &) {
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->ObtainEvalResult()->abstract();
                       });
  MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
  auto eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
  if (eval_result != nullptr) {
    return eval_result;
  }
  // Call the original evaluator to get the result: y = f(x).
  EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
  MS_EXCEPTION_IF_NULL(result);
  evaluator_cache_mgr_->SetValue(args_spec_list, result);
  return result;
}

EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                  const AnfNodeConfigPtr &) {
  AbstractBasePtrList args_spec_list;

@@ -372,6 +372,38 @@ class JEvaluator : public Evaluator {
  AbstractFunctionPtr orig_func_;
};

class TaylorEvaluator : public Evaluator {
 public:
  TaylorEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)
      : Evaluator("TaylorEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {}
  ~TaylorEvaluator() override = default;
  MS_DECLARE_PARENT(TaylorEvaluator, Evaluator);
  AnfNodePtr bound_node() const override {
    if (evaluator_ != nullptr) {
      return evaluator_->bound_node();
    }
    return bound_node_.lock();
  }

  void set_bound_node(const AnfNodePtr &node) override {
    if (evaluator_ != nullptr) {
      evaluator_->set_bound_node(node);
    }
    bound_node_ = AnfNodeWeakPtr(node);
  }

  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
    MS_LOG(EXCEPTION) << "Should not be called; the Run() method should be called instead.";
  }
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                    const AnfNodeConfigPtr &out_conf) override;

  std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

 private:
  EvaluatorPtr evaluator_;
  AbstractFunctionPtr orig_func_;
};
class ShardEvaluator : public Evaluator {
 public:
  ShardEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func)

@@ -501,6 +501,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<VmapTransfor
  return vmap_evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  AbstractFunctionPtr func_orig = func->fn();
  EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
  auto taylor_evaluator = std::make_shared<TaylorEvaluator>(evaluator_orig, func_orig);
  return taylor_evaluator;
}

EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &func) {
  MS_EXCEPTION_IF_NULL(func);
  AbstractFunctionPtr func_orig = func->fn();

@@ -548,6 +556,8 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
  } else if (func->isa<VmapTransformedAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<VmapTransformedAbstractClosure>>());
  } else if (func->isa<TaylorTransformedAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<TaylorTransformedAbstractClosure>>());
  } else if (func->isa<ShardTransformedAbstractClosure>()) {
    return _GetEvaluatorFor(func->cast<std::shared_ptr<ShardTransformedAbstractClosure>>());
  } else if (func->isa<VirtualAbstractClosure>()) {

@@ -324,6 +324,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VirtualAbstractClosure> &fn);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TypedPrimitiveAbstractClosure> &);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<JTransformedAbstractClosure> &fn);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<TaylorTransformedAbstractClosure> &fn);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<ShardTransformedAbstractClosure> &fn);
  EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr<VmapTransformedAbstractClosure> &fn);
@@ -151,6 +151,18 @@ py::function PrimitivePy::GetBpropFunction() {
  auto fn = GetBpropFunctionByObj(python_obj_);
  return fn;
}
py::function PrimitivePy::GetTaylorRuleFunction() {
  static const char *const get_taylor_rule_func_name = "get_taylor_rule";
  if (py::hasattr(python_obj_, get_taylor_rule_func_name)) {
    py::function fn = python_obj_.attr(get_taylor_rule_func_name)().cast<py::function>();
    return fn;
  }
  auto fn = GetTaylorRuleFunctionByObj(python_obj_);
  return fn;
}

py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args, const std::string &bprop_cls_name) {

@@ -50,6 +50,7 @@ class PrimitivePy : public Primitive {
  const bool parse_info_ = true;
  py::function GetVmapRuleFunction(const bool is_side_effect = false, int axis_size = 0);
  py::function GetBpropFunction();
  py::function GetTaylorRuleFunction();
  void set_signatures(const std::vector<Signature> &signatures);
  const std::vector<Signature> &signatures() const { return signatures_; }
  const std::map<int, py::function> &backward_hook_fn() const { return backward_hook_fn_; }

@@ -41,6 +41,18 @@ py::function GetBpropFunction(const std::string &name) {
  return fn;
}

py::function GetTaylorRuleFunctionByObj(const py::object &obj) {
  static const std::string get_taylor_fprop_fn = "get_taylor_fprop_fn";
  static const std::string ad_module = "mindspore.ops._grad";
  py::function fn = python_adapter::GetPyFn(ad_module, get_taylor_fprop_fn)(obj);
  return fn;
}

py::function GetTaylorRuleFunction(const std::string &name) {
  auto fn = GetTaylorRuleFunctionByObj(py::str(name));
  return fn;
}

py::function GetComputeFunction(const std::string &name) {
  static const std::string module = "mindspore._extends.builtin_operations";
  py::module mod = py::module::import(common::SafeCStr(module));
@@ -297,6 +297,20 @@ std::size_t JTransformedAbstractClosure::hash() const {
  return hash_value;
}

bool TaylorTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<TaylorTransformedAbstractClosure>()) {
    return false;
  }
  auto other_transformed = static_cast<const TaylorTransformedAbstractClosure *>(&other);
  return fn_ == other_transformed->fn_;
}

std::size_t TaylorTransformedAbstractClosure::hash() const {
  MS_EXCEPTION_IF_NULL(fn_);
  auto hash_value = hash_combine(tid(), fn_->hash());
  return hash_value;
}

bool ShardTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
  if (!other.isa<ShardTransformedAbstractClosure>()) {
    return false;

@@ -339,6 +339,36 @@ class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom {
  AbstractFuncAtomPtr fn_;
};

/// \brief TaylorTransformedAbstractClosure defines the interface for the abstract of a function
/// transformed through the application of Taylor.
class MS_CORE_API TaylorTransformedAbstractClosure final : public AbstractFuncAtom {
 public:
  /// \brief Constructor of TaylorTransformedAbstractClosure
  ///
  /// \param[in] fn The AbstractFuncAtom transformed through the application of Taylor.
  explicit TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}

  /// \brief Destructor of TaylorTransformedAbstractClosure
  ~TaylorTransformedAbstractClosure() override = default;
  MS_DECLARE_PARENT(TaylorTransformedAbstractClosure, AbstractFuncAtom)

  /// \brief Get the AbstractFuncAtom that this TaylorTransformedAbstractClosure corresponds to.
  ///
  /// \return The underlying AbstractFuncAtom.
  AbstractFuncAtomPtr fn() { return fn_; }

  AbstractFunctionPtr Copy() const override { return std::make_shared<TaylorTransformedAbstractClosure>(fn_); }

  bool operator==(const AbstractFunction &other) const override;
  std::size_t hash() const override;

  std::string ToString() const override { return "Taylor(" + fn_->ToString() + ")"; }

 private:
  AbstractFuncAtomPtr fn_;
};

/// \brief ShardTransformedAbstractClosure defines the interface for the abstract of a function
/// transformed through the application of Shard.
class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom {
@@ -171,6 +171,7 @@ constexpr auto kCSRDiv = "CSRDiv";

// Meta Function Graph
constexpr auto kJ = "J";
constexpr auto kVmap = "Vmap";
constexpr auto kTaylor = "Taylor";

// Others
constexpr auto kMakeTuple = "MakeTuple";

@@ -828,6 +829,7 @@ GVAR_DEF(PrimitivePtr, kPrimStateSetItem, std::make_shared<Primitive>("state_set
GVAR_DEF(PrimitivePtr, kPrimJ, std::make_shared<Primitive>(kJ, kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimVmap, std::make_shared<Primitive>(kVmap, kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimShard, std::make_shared<Primitive>("Shard", kSideEffectPropagate));
GVAR_DEF(PrimitivePtr, kPrimTaylor, std::make_shared<Primitive>(kTaylor));

// Used to build graph which have keyword arguments
GVAR_DEF(PrimitivePtr, kPrimExtractKeywordArg, std::make_shared<Primitive>("extract_keyword_arg"));
@@ -313,7 +313,7 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(cons
  }
}

// Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ and kPrimVmap.
// Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ, kPrimVmap and kPrimTaylor.
bool FuncGraphManager::func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const {
  MS_EXCEPTION_IF_NULL(meta_fg_prim_total_);
  MS_EXCEPTION_IF_NULL(fg);

@@ -704,7 +704,8 @@ void FuncGraphManager::OnEdgeAdded(const AnfNodePtr &node, int index, const AnfN
        signals_->InvalidateComputer();
      }
    }
    if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
    if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
        IsPrimitiveCNode(node, prim::kPrimTaylor)) {
      fg->AddMetaFgPrimValueNode(input);
    }
  } else if (fg != nullptr && fg != input->func_graph()) {

@@ -725,7 +726,8 @@ void FuncGraphManager::OnEdgeRemoved(const AnfNodePtr &node, int index, const An
        signals_->InvalidateComputer();
      }
    }
    if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap)) {
    if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
        IsPrimitiveCNode(node, prim::kPrimTaylor)) {
      fg->DropMetaFgPrimValueNode(input);
    }
  } else if (fg != nullptr && fg != input->func_graph()) {

@@ -1096,7 +1098,7 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
    return false;
  }

  // Check MetaFgPrim (J/Vmap) FuncGraph input.
  // Check MetaFgPrim (J/Vmap/Taylor) FuncGraph input.
  const auto &meta_fg_prim_values = fg->meta_fg_prim_value_nodes();
  if (!meta_fg_prim_values.empty()) {
    auto contains_meta_fg_prim =

@@ -1107,9 +1109,10 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
          return func_graph->seen_ != seen_num;
        }
        if (IsValueNode<Primitive>(iter.first)) {
          // Exclude the primitive of MetaFgPrim (J/Vmap) itself.
          // Exclude the primitive of MetaFgPrim (J/Vmap/Taylor) itself.
          auto prim = GetValueNode<PrimitivePtr>(iter.first);
          return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name());
          return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name() &&
                  prim->name() != prim::kPrimTaylor->name());
        }
        return false;
      });

@@ -1119,35 +1122,38 @@ bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, Se
    }
  }

  // Check MetaFgPrim (J/Vmap) CNode as FV.
  // Check MetaFgPrim (J/Vmap/Taylor) CNode as FV.
  const auto &fv_nodes = fg->free_variables();
  if (!fv_nodes.empty()) {
    auto contains_meta_fg_prim_cnode = std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const auto &iter) {
      // Check if the FV is a MetaFgPrim (J/Vmap) call CNode.
      if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap)) {
      // Check if the FV is a MetaFgPrim (J/Vmap/Taylor) call CNode.
      if (IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap) ||
          IsPrimitiveCNode(iter.first, prim::kPrimTaylor)) {
        return true;
      }
      return false;
    });
    if (contains_meta_fg_prim_cnode != fv_nodes.end()) {
      MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap) ("
      MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap/Taylor) ("
                    << contains_meta_fg_prim_cnode->first->DebugString() << ")";
      return true;
    }
  }

  // Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)
  // Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive),
  // Taylor(func_graph) or Taylor(Primitive).
  fg->seen_ = seen_num;
  for (auto &item : fg->func_graphs_used()) {
    auto used_g = item.first;
    if (SeekMetaFgPrim(used_g, seen_num)) {
      MS_LOG(DEBUG) << fg->ToString() << " uses func graph " << used_g->ToString()
                    << " which contains J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
                    << " which contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
                    << "Taylor(func_graph) or Taylor(Primitive)";
      return true;
    }
  }
  MS_LOG(DEBUG) << fg->ToString()
                << " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph) or Vmap(Primitive)";
  MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
                << "Taylor(func_graph) or Taylor(Primitive)";
  return false;
}
@@ -15,7 +15,7 @@
"""grad impl."""
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
    grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops
from .grad_base import get_bprop_fn
    grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops, grad_sparse, grad_inner_ops, taylor_rule
from .grad_base import get_bprop_fn, get_taylor_fprop_fn

__all__ = ['get_bprop_fn']
__all__ = ['get_bprop_fn', 'get_taylor_fprop_fn']

@@ -37,8 +37,28 @@ class BpropRegistry(Registry):
        return deco


class TaylorFpropRegistry(Registry):
    """Registry class for registering taylor rule functions on Primitive or string."""

    def register(self, prim):
        """Register the function."""

        def deco(fn):
            """Decorate the function."""
            if isinstance(prim, str):
                self[prim] = fn
            elif issubclass(prim, Primitive):
                self[id(prim)] = fn
                self[prim.__name__] = fn
            return fn

        return deco


bprop_getters = BpropRegistry()
bprops = BpropRegistry()
taylor_fprop_getters = TaylorFpropRegistry()
taylor_fprops = TaylorFpropRegistry()


def get_bprop_fn(prim):
@@ -47,3 +67,11 @@ def get_bprop_fn(prim):
    if out:
        return out(prim)
    return bprops.get(prim, None)


def get_taylor_fprop_fn(prim):
    """Get the taylor rule function by primitive object or primitive name (used from C++)."""
    out = taylor_fprop_getters.get(prim, None)
    if out:
        return out(prim)
    return taylor_fprops.get(prim, None)
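To make the registration contract concrete, here is a minimal, hypothetical rule modeled on the entries in taylor_rule.py below. `P.Neg` is not wired into this patch (it is absent from the C++ `taylor_ops` white list) and serves only as an example; negation is linear, so every Taylor coefficient is simply negated:

    from mindspore.ops import operations as P
    from mindspore.ops._grad.grad_base import taylor_fprop_getters


    @taylor_fprop_getters.register(P.Neg)
    def taylor_neg(self):
        """Hypothetical higher order derivatives rule for the linear `Neg` operation."""
        neg = P.Neg()

        def taylor_fprop_neg(inputs):
            # Row 0 holds the primal, rows 1..n the derivative series; Neg acts elementwise on all rows.
            return neg(inputs)

        return taylor_fprop_neg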
@@ -0,0 +1,166 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the taylor rules of operations."""
from mindspore import nn
import mindspore as ms
from ..primitive import Primitive
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import taylor_fprop_getters


def _factorial(order):
    """Return [0!, 1!, 2!,..., order!]."""
    range_op = nn.Range(1, order + 1)
    ones_op = P.Ones()
    concat_op = P.Concat()
    factorial_zero = ones_op(1, ms.float32)
    factorial_positive = range_op().astype(ms.float32)
    for i in range(1, order):
        factorial_positive[i] *= factorial_positive[i - 1]
    factorial = concat_op((factorial_zero, factorial_positive))
    return factorial
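A quick sanity check of the helper (the values follow directly from the loop above): `_factorial(3)` evaluates to `Tensor([1., 1., 2., 6.])`, i.e. `[0!, 1!, 2!, 3!]`.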
@taylor_fprop_getters.register(P.Add)
@taylor_fprop_getters.register(P.Sub)
def taylor_add_or_sub(self):
    """Higher order derivatives rule definition for `Add` or `Sub` operation."""
    if isinstance(self, str):
        prim = Primitive(self)
    else:
        prim = self

    def taylor_fprop_add_or_sub(input_x, input_y):
        series = prim(input_x, input_y)
        return series

    return taylor_fprop_add_or_sub
@taylor_fprop_getters.register(P.Mul)
def taylor_mul(self):
    """Higher order derivatives rule definition for `Mul` operation."""
    mul_func = P.Mul()

    def taylor_fprop_mul(input_x, input_y):
        primals = mul_func(input_x[0], input_y[0])
        series_num = len(input_x) - 1
        factorial = _factorial(series_num)
        series = zeros_like(input_x)
        series[0] = primals
        for k in range(1, series_num + 1):
            for i in range(0, k + 1):
                tmp = input_x[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
                series[k] += tmp
            series[k] *= factorial[k]
        return series

    return taylor_fprop_mul
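The nested loop is the Leibniz rule written directly on derivatives: with $a_i = f^{(i)}(x)$ stored in `input_x` and $b_j = g^{(j)}(x)$ in `input_y`, it computes

$(fg)^{(k)} = \sum_{i=0}^{k} \binom{k}{i} a_i b_{k-i} = k!\sum_{i=0}^{k} \frac{a_i}{i!}\cdot\frac{b_{k-i}}{(k-i)!},$

which is the accumulation over `tmp` followed by the final multiplication by `factorial[k]`.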
@taylor_fprop_getters.register(P.RealDiv)
def taylor_realdiv(self):
    """Higher order derivatives rule definition for `RealDiv` operation."""
    div_op = P.Div()

    def taylor_fprop_realdiv(input_x, input_y):
        primals = div_op(input_x[0], input_y[0])
        series_num = len(input_x) - 1
        factorial = _factorial(series_num)
        series = zeros_like(input_x)
        series[0] = primals
        for k in range(1, series_num + 1):
            for i in range(0, k):
                tmp = series[i] * input_y[k - i] / (factorial[k - i] * factorial[i])
                series[k] += tmp
            series[k] = (input_x[k] - factorial[k] * series[k]) / input_y[0]
        return series

    return taylor_fprop_realdiv
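The recurrence comes from writing $f = g \cdot h$ with $h = f/g$ and solving the Leibniz rule for the highest coefficient: with $a_k$, $b_k$, $c_k$ the $k$-th derivatives of $f$, $g$ and $h$,

$c_k = \frac{1}{b_0}\left(a_k - k!\sum_{i=0}^{k-1} \frac{c_i}{i!}\cdot\frac{b_{k-i}}{(k-i)!}\right),$

matching the partial sum over `series[i]` and the final correction against `input_x[k]` and `input_y[0]`.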
@taylor_fprop_getters.register(P.Exp)
def taylor_exp(self):
    """Higher order derivatives rule definition for `Exp` operation."""
    exp_ = P.Exp()

    def taylor_fprop_exp(inputs):
        primals = exp_(inputs[0])
        series_num = len(inputs) - 1
        factorial = _factorial(series_num)
        series = zeros_like(inputs)
        series[0] = primals
        for k in range(1, series_num + 1):
            for i in range(1, k + 1):
                tmp = i * inputs[i] * series[k - i] / (factorial[k - i] * factorial[i])
                series[k] += tmp
            series[k] *= factorial[k - 1]
        return series

    return taylor_fprop_exp
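The loop implements the recurrence obtained by differentiating $(e^f)' = f'e^f$ repeatedly: with $a_i$ the derivatives of $f$ and $c_k$ those of $e^f$,

$c_k = (k-1)!\sum_{i=1}^{k} \frac{i\,a_i\,c_{k-i}}{i!\,(k-i)!},$

which is why the final scaling uses `factorial[k - 1]` rather than `factorial[k]`.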
@taylor_fprop_getters.register(P.Sin)
def taylor_sin(self):
    """Higher order derivatives rule definition for `Sin` operation."""
    cos = P.Cos()
    sin = P.Sin()

    def taylor_fprop_sin(inputs):
        primal_sin = sin(inputs[0])
        primal_cos = cos(inputs[0])
        series_sin = zeros_like(inputs)
        series_cos = zeros_like(inputs)
        series_sin[0] = primal_sin
        series_cos[0] = primal_cos
        series_num = len(inputs) - 1
        factorial = _factorial(series_num)
        for k in range(1, series_num + 1):
            for i in range(1, k + 1):
                series_sin[k] += i * inputs[i] * series_cos[k - i] / (factorial[i] * factorial[k - i])
                series_cos[k] -= i * inputs[i] * series_sin[k - i] / (factorial[i] * factorial[k - i])
            series_sin[k] *= factorial[k - 1]
            series_cos[k] *= factorial[k - 1]
        return series_sin

    return taylor_fprop_sin
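`Sin` and `Cos` have to be propagated as a coupled pair: $s = \sin f$ and $c = \cos f$ satisfy $s' = f'c$ and $c' = -f's$, and the same derivation as for `Exp` yields

$s_k = (k-1)!\sum_{i=1}^{k} \frac{i\,a_i\,c_{k-i}}{i!\,(k-i)!}, \qquad c_k = -(k-1)!\sum_{i=1}^{k} \frac{i\,a_i\,s_{k-i}}{i!\,(k-i)!}.$

The `Cos` rule below is the mirror image of this one and returns `series_cos` instead of `series_sin`.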
@taylor_fprop_getters.register(P.Cos)
def taylor_cos(self):
    """Higher order derivatives rule definition for `Cos` operation."""
    cos = P.Cos()
    sin = P.Sin()

    def taylor_fprop_cos(inputs):
        primal_cos = cos(inputs[0])
        primal_sin = sin(inputs[0])
        series_cos = zeros_like(inputs)
        series_sin = zeros_like(inputs)
        series_cos[0] = primal_cos
        series_sin[0] = primal_sin
        series_num = len(inputs) - 1
        factorial = _factorial(series_num)
        for k in range(1, series_num + 1):
            for i in range(1, k + 1):
                series_cos[k] -= i * inputs[i] * series_sin[k - i] / (factorial[i] * factorial[k - i])
                series_sin[k] += i * inputs[i] * series_cos[k - i] / (factorial[i] * factorial[k - i])
            series_cos[k] *= factorial[k - 1]
            series_sin[k] *= factorial[k - 1]
        return series_cos

    return taylor_fprop_cos
@@ -21,7 +21,7 @@ Pre-defined combination of operators.
from .base import GradOperation, _Grad, HyperMap, Map, MultitypeFuncGraph, add_flags, \
    core, env_get, tail, zip_operation, Shard, _Vmap
    core, env_get, tail, zip_operation, Shard, _Vmap, _TaylorOperation
from .clip_ops import clip_by_value, clip_by_global_norm
from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like

@@ -22,6 +22,7 @@ from types import FunctionType
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
    ListSlice_, VmapOperation_
    ListSlice_, VmapOperation_, TaylorOperation_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
@@ -401,6 +401,29 @@ class GradOperation(GradOperation_):
        return self.grad_fn


class _TaylorOperation(TaylorOperation_):
    """
    Generate the higher order derivatives function for the input function.
    """

    def __init__(self):
        """Initialize _TaylorOperation."""
        TaylorOperation_.__init__(self, 'taylorgrad')
        self.grad_fn = None
        self.fn = None

    def __call__(self, fn):
        if self.grad_fn is not None and self.fn == fn:
            return self.grad_fn
        taylor_grad_ = _TaylorOperation()

        # If the taylor grad is called in GRAPH_MODE or inside ms_function, do the grad in GRAPH_MODE.
        @ms_function
        def after_taylor_grad(*args):
            return taylor_grad_(fn)(*args)

        self.grad_fn = after_taylor_grad
        self.fn = fn
        return self.grad_fn


class _Grad(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function by position for the input function.
@@ -31,7 +31,7 @@ from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .composite import _Grad, Shard, _Vmap
from .composite import _Grad, Shard, _Vmap, _TaylorOperation
from .._c_expression import security

typeof = Primitive('typeof')
@@ -47,6 +47,7 @@ eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
size = P.Size()
ones = P.Ones()
ones_like = P.OnesLike()
shape = P.Shape()
dyn_shape = P.TensorShape()

@@ -284,6 +285,187 @@ def grad(fn, grad_position=0, sens_param=False):
        return grad_by_position_with_sens(fn, None, grad_position)
    return grad_by_position(fn, None, grad_position)
@constexpr
def _trans_jet_inputs(primals_item, series_item):
    """Transform the inputs of jet."""
    value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
    if dtype(primals_item) not in value_type or dtype(primals_item) != dtype(series_item):
        raise TypeError(f"For `F.jet`, the elements' types of primals and series should be the same and belong to "
                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got "
                        f"{dtype(primals_item).__name__} and {dtype(series_item).__name__}.")
    if dtype(primals_item) in [mstype.int32, mstype.int64]:
        return cast(primals_item, mstype.float64), cast(series_item, mstype.float64)
    return primals_item, series_item


@constexpr
def _check_jet_inputs(primals, series):
    """Check the inputs of jet."""
    if not isinstance(primals, type(series)) or not isinstance(primals, (Tensor, tuple)):
        raise TypeError(f"For 'F.jet', the 'primals' and `series` should both be Tensor or tuple, "
                        f"but got {type(primals).__name__} and {type(series).__name__}.")
    if isinstance(primals, Tensor):
        if primals.shape != series.shape[1:]:
            raise ValueError("The shape of each element should be the same as the primals.")
        return _trans_jet_inputs(primals, series)
    if isinstance(primals, tuple):
        if len(primals) != len(series):
            raise ValueError("The lengths of primals and series should be the same.")
    check_primals = []
    check_series = []
    for i, j in zip(primals, series):
        trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
        check_primals.append(trans_primals_item)
        check_series.append(trans_series_item)
    return check_primals, check_series
_taylor = _TaylorOperation()


def jet(fn, primals, series):
    """
    This function is designed to calculate the higher order differentiation of a given composite function. To figure
    out the first to `n`-th order differentiations, the original inputs and the first to `n`-th order derivatives of
    the original inputs must be provided together. Generally, it is recommended to set the first order derivative
    to 1 and the others to 0.

    Args:
        fn (Union(Cell, function)): Function to do TaylorOperation.
        primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
        series (Union(Tensor, Tuple of Tensors)): If tuple, the length and type of series should be the same as inputs.
            For each Tensor, the length `i` of its first dimension means that the `1` to `i`-th order derivatives of
            the output with respect to the inputs will be figured out.

    Returns:
        Tuple, tuple of out_primals and out_series.

        - **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
        - **out_series** (Tensors or List of Tensors) - The `1` to `i`-th order derivatives of the output with respect
          to the inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.context as context
        >>> import mindspore.ops as P
        >>> from mindspore import Tensor
        >>> from mindspore.ops.functional import jet
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.sin = P.Sin()
        ...         self.exp = P.Exp()
        ...     def construct(self, x):
        ...         out1 = self.sin(x)
        ...         out2 = self.exp(out1)
        ...         return out2
        >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
        >>> net = Net()
        >>> out_primals, out_series = jet(net, primals, series)
        >>> print(out_primals, out_series)
    """
    primals, series = _check_jet_inputs(primals, series)
    derivative_fn = _taylor(fn)
    concat_op = P.Concat()
    if isinstance(primals, list) and list_len(primals) > 1:
        inputs = list(map(lambda x, y: concat_op((expand_dims(x, 0), y)), primals, series))
        outputs = derivative_fn(*inputs)
    else:
        inputs = concat_op((expand_dims(primals, 0), series))
        outputs = derivative_fn(inputs)
    if isinstance(outputs, list) and list_len(outputs) > 1:
        out_primals = [element[0] for element in outputs]
        out_series = [element[1:] for element in outputs]
    else:
        out_primals = outputs[0]
        out_series = outputs[1:]
    return out_primals, out_series
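Shape bookkeeping for the single-tensor branch, as a sketch (shapes assume `primals` of shape `(m, n)` and `series` of shape `(order, m, n)`, as in the docstring example):

    # stacked input to the taylor graph: concat -> (order + 1, m, n), row 0 is the primal
    # stacked output: outputs[0] -> fn(primals); outputs[1:] -> 1st..order-th derivatives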
@constexpr
def _trans_derivative_inputs(primals_item):
    """Transform the inputs of derivative."""
    value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
    if dtype(primals_item) not in value_type:
        raise TypeError(f"For `F.derivative`, the elements of primals should belong to "
                        f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got "
                        f"{dtype(primals_item).__name__}.")
    if dtype(primals_item) in [mstype.int32, mstype.int64]:
        return cast(primals_item, mstype.float64)
    return primals_item


def derivative(fn, primals, order):
    """
    This function is designed to calculate the higher order differentiation of a given composite function. To figure
    out the `order`-th order differentiation, the original inputs and the order must be provided together. In
    particular, the value of the input first order derivative is set to 1, while the others are set to 0.

    Args:
        fn (Union(Cell, function)): Function to do TaylorOperation.
        primals (Union(Tensor, Tuple of Tensors)): The inputs to `fn`.
        order (int): For each Tensor, the `order`-th order of derivative of output with respect to the inputs will be
            figured out.

    Returns:
        Tuple, tuple of out_primals and out_series.

        - **out_primals** (Tensors or List of Tensors) - The output of `fn(primals)`.
        - **out_series** (Tensors or List of Tensors) - The `order`-th order of derivative of output with respect
          to the inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.context as context
        >>> import mindspore.ops as P
        >>> from mindspore import Tensor
        >>> from mindspore.ops.functional import derivative
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.sin = P.Sin()
        ...         self.exp = P.Exp()
        ...     def construct(self, x):
        ...         out1 = self.sin(x)
        ...         out2 = self.exp(out1)
        ...         return out2
        >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> order = 3
        >>> net = Net()
        >>> out_primals, out_series = derivative(net, primals, order)
        >>> print(out_primals, out_series)
    """
    derivative_fn = _taylor(fn)
    concat_op = P.Concat()
    series_one = 1
    if isinstance(primals, tuple):
        trans_primals = [_trans_derivative_inputs(item) for item in primals]
        inputs = list(map(lambda x: concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x)))), trans_primals))
        if order > 1:
            inputs = list(map(lambda x: concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x)))), inputs))
        outputs = derivative_fn(*inputs)
    else:
        primals = _trans_derivative_inputs(primals)
        series = zeros((order,) + primals.shape, dtype(primals))
        series[0] = series_one
        inputs = concat_op((expand_dims(primals, 0), series))
        outputs = derivative_fn(inputs)
    if isinstance(outputs, tuple) and tuple_len(outputs) > 1:
        out_primals = [element[0] for element in outputs]
        out_series = [element[-1] for element in outputs]
    else:
        out_primals = outputs[0]
        out_series = outputs[-1]
    return out_primals, out_series
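Why this recovers the plain `order`-th derivative: the stacked input is the derivative series of the curve $x(t) = x_0 + t$ at $t = 0$ (first derivative 1, higher ones 0), so by the chain rule row $k$ of the output equals $\frac{d^k}{dt^k} f(x_0 + t)\big|_{t=0} = f^{(k)}(x_0)$, and `outputs[-1]` is exactly the requested order.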
def jvp(fn, inputs, v):
    """
@@ -38,6 +38,11 @@ class MultipleInputsOutputNet(nn.Cell):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_single_input_graph():
    """
    Features: Function vjp
    Description: Test vjp with single input, single output and default v in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputNet()

@@ -53,6 +58,11 @@ def test_vjp_single_input_graph():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_multiple_inputs_default_v_graph():
    """
    Features: Function vjp
    Description: Test vjp with multiple inputs, multiple outputs and default v in graph mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))

@@ -140,8 +150,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph():
            net_out, vjp_out = vjp(self.net, inputs, vectors)
            return net_out, vjp_out

    test_net = Net(SingleInputNet())
    primal, grad = test_net(x, v)
    test_net_graph = Net(SingleInputNet())
    primal, grad = test_net_graph(x, v)
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
@@ -18,7 +18,6 @@ import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ms_function
from mindspore.ops.functional import vjp

context.set_context(mode=context.PYNATIVE_MODE)

@@ -37,12 +36,17 @@ class MultipleInputsOutputNet(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vjp_single_input_graph():
def test_vjp_single_input_pynative():
    """
    Features: Function vjp
    Description: Test vjp with single input, single output and default v in pynative mode.
    Expectation: No exception.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputNet()
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
| primal, grad = vjp(net, x, v) | |||
| assert np.allclose(primal.asnumpy(), expect_primal.asnumpy()) | |||
| assert np.allclose(grad.asnumpy(), expect_grad.asnumpy()) | |||
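The expected values in these vjp tests follow from the cube function: `SingleInputNet` is consistent with `inputs ** 3` (the function variant later in this file returns exactly that), so the primal is `x**3` and the cotangent is `v * 3x**2`. A quick NumPy sketch, assuming elementwise semantics:

```python
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
v = np.ones_like(x)

primal = x ** 3          # [[1, 8], [27, 64]]
grad = v * 3 * x ** 2    # vjp of y = x**3: v * dy/dx = [[3, 12], [27, 48]]
```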
| @@ -51,58 +55,38 @@ def test_vjp_single_input_graph(): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_vjp_multiple_inputs_default_v_graph(): | |||
| def test_vjp_multiple_inputs_default_v_pynative(): | |||
| """ | |||
| Features: Function vjp | |||
| Description: Test vjp with multiple inputs, multiple outputs and default v in pynative mode. | |||
| Expectation: No exception. | |||
| """ | |||
| x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) | |||
| y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) | |||
| v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) | |||
| net = MultipleInputsOutputNet() | |||
| expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32)) | |||
| expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32)) | |||
| expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32)) | |||
| expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32)) | |||
| expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32)) | |||
| expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32)) | |||
| primal, grad = vjp(net, (x, y), (v, v)) | |||
| assert isinstance(primal, tuple) | |||
| assert len(primal) == 2 | |||
| assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy()) | |||
| assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy()) | |||
| assert isinstance(grad, tuple) | |||
| assert len(grad) == 2 | |||
| assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy()) | |||
| assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy()) | |||
| assert isinstance(primal, tuple) | |||
| assert len(primal) == 2 | |||
| assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy()) | |||
| assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy()) | |||
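The multiple-input expectations are consistent with a net of the form `construct(x, y) -> (2 * x, y ** 3)`; the actual `MultipleInputsOutputNet` body is elided by the diff, so treat that form as an assumption. Under it, each input's cotangent accumulates the v-weighted partials of every output, and only one output depends on each input:

```python
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = x.copy()
v = np.ones_like(x)

# Hypothetical net consistent with the expected values: f(x, y) = (2x, y**3).
primal_0, primal_1 = 2 * x, y ** 3   # [[2, 4], [6, 8]] and [[1, 8], [27, 64]]
grad_x = 2 * v                       # [[2, 2], [2, 2]]
grad_y = 3 * y ** 2 * v              # [[3, 12], [27, 48]]
```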
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_vjp_ms_function_single_input_single_output_default_v_graph(): | |||
| """ | |||
| Features: Function vjp | |||
| Description: Test vjp with ms_function, single input, single output and default v in graph mode. | |||
| Expectation: No exception. | |||
| """ | |||
| x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) | |||
| v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) | |||
| net = SingleInputNet() | |||
| @ms_function | |||
| def vjp_with_ms_function(inputs, vectors): | |||
| output, vjp_grad = vjp(net, inputs, vectors) | |||
| return output, vjp_grad | |||
| primal, grad = vjp_with_ms_function(x, v) | |||
| expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32)) | |||
| expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32)) | |||
| assert np.allclose(primal.asnumpy(), expect_primal.asnumpy()) | |||
| assert np.allclose(grad.asnumpy(), expect_grad.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_vjp_input_function_single_input_single_output_default_v_graph(): | |||
| def test_vjp_input_function_single_input_single_output_default_v_pynative(): | |||
| """ | |||
| Features: Function vjp | |||
| Description: Test vjp with function, single input, single output and default v in graph mode. | |||
| Description: Test vjp with function, single input, single output and default v in pynative mode. | |||
| Expectation: No exception. | |||
| """ | |||
| x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) | |||
| @@ -112,8 +96,8 @@ def test_vjp_input_function_single_input_single_output_default_v_graph(): | |||
| return inputs**3 | |||
| primal, grad = vjp(test_function, x, v) | |||
| expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32)) | |||
| expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32)) | |||
| expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32)) | |||
| assert np.allclose(primal.asnumpy(), expect_primal.asnumpy()) | |||
| assert np.allclose(grad.asnumpy(), expect_grad.asnumpy()) | |||
| @@ -121,10 +105,10 @@ def test_vjp_input_function_single_input_single_output_default_v_graph(): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_vjp_construct_single_input_single_output_default_v_graph(): | |||
| def test_vjp_construct_single_input_single_output_default_v_pynative(): | |||
| """ | |||
| Features: Function vjp | |||
| Description: Test vjp with function, single input, single output and default v in graph mode. | |||
| Description: Test vjp wrapped in a Cell construct, with single input, single output and default v in pynative mode. | |||
| Expectation: No exception. | |||
| """ | |||
| x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) | |||
| @@ -139,8 +123,8 @@ def test_vjp_construct_single_input_single_output_default_v_graph(): | |||
| net_out, vjp_out = vjp(self.net, inputs, vectors) | |||
| return net_out, vjp_out | |||
| test_net = Net(SingleInputNet()) | |||
| primal, grad = test_net(x, v) | |||
| test_net_pynative = Net(SingleInputNet()) | |||
| primal, grad = test_net_pynative(x, v) | |||
| expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32)) | |||
| expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32)) | |||
| assert np.allclose(primal.asnumpy(), expect_primal.asnumpy()) | |||
| @@ -0,0 +1,142 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test taylor differentiation in graph mode""" | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| from mindspore.ops import operations as P | |||
| from mindspore import Tensor | |||
| from mindspore.ops.functional import jet, derivative | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class MultipleInputSingleOutputNet(nn.Cell): | |||
| def __init__(self): | |||
| super(MultipleInputSingleOutputNet, self).__init__() | |||
| self.sin = P.Sin() | |||
| self.cos = P.Cos() | |||
| self.exp = P.Exp() | |||
| def construct(self, x, y): | |||
| out1 = self.sin(x) | |||
| out2 = self.cos(y) | |||
| out3 = out1 * out2 + out1 / out2 | |||
| out = self.exp(out3) | |||
| return out | |||
| class SingleInputSingleOutputNet(nn.Cell): | |||
| def __init__(self): | |||
| super(SingleInputSingleOutputNet, self).__init__() | |||
| self.sin = P.Sin() | |||
| self.cos = P.Cos() | |||
| self.exp = P.Exp() | |||
| def construct(self, x): | |||
| out1 = self.sin(x) | |||
| out2 = self.cos(out1) | |||
| out3 = self.exp(out2) | |||
| out = out1 + out2 - out3 | |||
| return out | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_jet_single_input_single_output_graph_mode(): | |||
| """ | |||
| Features: Function jet | |||
| Description: Test jet with single input in graph mode. | |||
| Expectation: No exception. | |||
| """ | |||
| primals = Tensor([1., 1.]) | |||
| series = Tensor([[1., 1.], [0., 0.], [0., 0.]]) | |||
| net = SingleInputSingleOutputNet() | |||
| expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32) | |||
| expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32) | |||
| out_primals, out_series = jet(net, primals, series) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
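The expectations above can be cross-checked symbolically: with a unit first-order series, `out_series[k-1]` should match the k-th derivative of `f(x) = sin(x) + cos(sin(x)) - exp(cos(sin(x)))` at the primal point. A SymPy sketch (SymPy is not a dependency of the test itself):

```python
import sympy as sp

x = sp.symbols('x')
f = sp.sin(x) + sp.cos(sp.sin(x)) - sp.exp(sp.cos(sp.sin(x)))

print(f.subs(x, 1).evalf())                     # ~ -0.43931  (primal)
for k in (1, 2, 3):
    print(sp.diff(f, x, k).subs(x, 1).evalf())  # ~ 0.92187, -1.56750, -0.74808
```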
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_derivative_single_input_single_output_graph_mode(): | |||
| """ | |||
| Features: Function derivative | |||
| Description: Test derivative with single input in graph mode. | |||
| Expectation: No exception. | |||
| """ | |||
| primals = Tensor([1., 1.]) | |||
| order = 3 | |||
| net = SingleInputSingleOutputNet() | |||
| expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32) | |||
| expected_series = np.array([-0.74808, -0.74808]).astype(np.float32) | |||
| out_primals, out_series = derivative(net, primals, order) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_jet_multiple_input_single_output_graph_mode(): | |||
| """ | |||
| Features: Function jet | |||
| Description: Test jet with multiple inputs in graph mode. | |||
| Expectation: No exception. | |||
| """ | |||
| primals = (Tensor([1., 1.]), Tensor([1., 1.])) | |||
| series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]])) | |||
| net = MultipleInputSingleOutputNet() | |||
| expected_primals = np.array([7.47868, 7.47868]).astype(np.float32) | |||
| expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32) | |||
| out_primals, out_series = jet(net, primals, series) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_derivative_multiple_input_single_output_graph_mode(): | |||
| """ | |||
| Features: Function derivative | |||
| Description: Test derivative with multiple inputs in graph mode. | |||
| Expectation: No exception. | |||
| """ | |||
| primals = (Tensor([1., 1.]), Tensor([1., 1.])) | |||
| order = 3 | |||
| net = MultipleInputSingleOutputNet() | |||
| expected_primals = np.array([7.47868, 7.47868]).astype(np.float32) | |||
| expected_series = np.array([1237.959, 1237.959]).astype(np.float32) | |||
| out_primals, out_series = derivative(net, primals, order) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||
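For the multiple-input case, giving both inputs a unit first-order series amounts to differentiating along the diagonal, i.e. taking derivatives of `h(t) = f(1 + t, 1 + t)` at `t = 0` with `f(x, y) = exp(sin(x)*cos(y) + sin(x)/cos(y))`. A SymPy cross-check sketch:

```python
import sympy as sp

t = sp.symbols('t')
h = sp.exp(sp.sin(1 + t) * sp.cos(1 + t) + sp.sin(1 + t) / sp.cos(1 + t))

print(h.subs(t, 0).evalf())                     # ~ 7.47868  (primal)
for k in (1, 2, 3):
    print(sp.diff(h, t, k).subs(t, 0).evalf())  # ~ 22.50614, 133.92517, 1237.959
```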
| @@ -0,0 +1,142 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """test taylor differentiation in pynative mode""" | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| from mindspore.ops import operations as P | |||
| from mindspore import Tensor | |||
| from mindspore.ops.functional import jet, derivative | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| class SingleInputSingleOutputNet(nn.Cell): | |||
| def __init__(self): | |||
| super(SingleInputSingleOutputNet, self).__init__() | |||
| self.exp = P.Exp() | |||
| self.cos = P.Cos() | |||
| self.sin = P.Sin() | |||
| def construct(self, x): | |||
| out1 = self.sin(x) | |||
| out2 = self.cos(out1) | |||
| out3 = self.exp(out2) | |||
| out = out1 + out2 - out3 | |||
| return out | |||
| class MultipleInputSingleOutputNet(nn.Cell): | |||
| def __init__(self): | |||
| super(MultipleInputSingleOutputNet, self).__init__() | |||
| self.exp = P.Exp() | |||
| self.cos = P.Cos() | |||
| self.sin = P.Sin() | |||
| def construct(self, x, y): | |||
| out1 = self.sin(x) | |||
| out2 = self.cos(y) | |||
| out3 = out1 * out2 + out1 / out2 | |||
| out = self.exp(out3) | |||
| return out | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_jet_multiple_input_single_output_pynative_mode(): | |||
| """ | |||
| Features: Function jet | |||
| Description: Test jet with multiple inputs in pynative mode. | |||
| Expectation: No exception. | |||
| """ | |||
| series = (Tensor([[1., 1.], [0., 0.], [0., 0.]]), Tensor([[1., 1.], [0., 0.], [0., 0.]])) | |||
| primals = (Tensor([1., 1.]), Tensor([1., 1.])) | |||
| net = MultipleInputSingleOutputNet() | |||
| expected_primals = np.array([7.47868, 7.47868]).astype(np.float32) | |||
| expected_series = np.array([[22.50614, 22.50614], [133.92517, 133.92517], [1237.959, 1237.959]]).astype(np.float32) | |||
| out_primals, out_series = jet(net, primals, series) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_derivative_multiple_input_single_output_pynative_mode(): | |||
| """ | |||
| Features: Function derivative | |||
| Description: Test derivative with multiple inputs in pynative mode. | |||
| Expectation: No exception. | |||
| """ | |||
| primals = (Tensor([1., 1.]), Tensor([1., 1.])) | |||
| order = 3 | |||
| net = MultipleInputSingleOutputNet() | |||
| expected_primals = np.array([7.47868, 7.47868]).astype(np.float32) | |||
| expected_series = np.array([1237.959, 1237.959]).astype(np.float32) | |||
| out_primals, out_series = derivative(net, primals, order) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_jet_single_input_single_output_pynative_mode(): | |||
| """ | |||
| Features: Function jet | |||
| Description: Test jet with single input in pynative mode. | |||
| Expectation: No exception. | |||
| """ | |||
| primals = Tensor([1., 1.]) | |||
| series = Tensor([[1., 1.], [0., 0.], [0., 0.]]) | |||
| net = SingleInputSingleOutputNet() | |||
| expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32) | |||
| expected_series = np.array([[0.92187, 0.92187], [-1.56750, -1.56750], [-0.74808, -0.74808]]).astype(np.float32) | |||
| out_primals, out_series = jet(net, primals, series) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_derivative_single_input_single_output_pynative_mode(): | |||
| """ | |||
| Features: Function derivative | |||
| Description: Test derivative with single input in pynative mode. | |||
| Expectation: No exception. | |||
| """ | |||
| primals = Tensor([1., 1.]) | |||
| order = 3 | |||
| net = SingleInputSingleOutputNet() | |||
| expected_primals = np.array([-0.43931, -0.43931]).astype(np.float32) | |||
| expected_series = np.array([-0.74808, -0.74808]).astype(np.float32) | |||
| out_primals, out_series = derivative(net, primals, order) | |||
| assert np.allclose(out_primals.asnumpy(), expected_primals, atol=1.e-4) | |||
| assert np.allclose(out_series.asnumpy(), expected_series, atol=1.e-4) | |||