| @@ -30,6 +30,7 @@ | |||||
| #include "pipeline/parse/python_adapter.h" | #include "pipeline/parse/python_adapter.h" | ||||
| #include "pipeline/parse/resolve.h" | #include "pipeline/parse/resolve.h" | ||||
| #include "operator/composite/composite.h" | #include "operator/composite/composite.h" | ||||
| #include "operator/composite/map.h" | |||||
| #include "utils/ordered_map.h" | #include "utils/ordered_map.h" | ||||
| #include "utils/ordered_set.h" | #include "utils/ordered_set.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| @@ -190,6 +191,8 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||||
| * ├── MultitypeGraph | * ├── MultitypeGraph | ||||
| * ├── HyperMap | * ├── HyperMap | ||||
| * │ └── HyperMapPy | * │ └── HyperMapPy | ||||
| * ├── Map | |||||
| * │ └── MapPy | |||||
| * ├── Tail | * ├── Tail | ||||
| * ├── MakeTupleGradient | * ├── MakeTupleGradient | ||||
| * ├── GradOperation | * ├── GradOperation | ||||
| @@ -208,17 +211,25 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_ | |||||
| oss << GetMultitypeFuncGraphText(mt_func_graph); | oss << GetMultitypeFuncGraphText(mt_func_graph); | ||||
| } else if (meta_func_graph | } else if (meta_func_graph | ||||
| ->isa<prim::HyperMapPy>()) { // this statement must before 'meta_graph->isa<prim::HyperMap>()' | ->isa<prim::HyperMapPy>()) { // this statement must before 'meta_graph->isa<prim::HyperMap>()' | ||||
| prim::HyperMapPyPtr hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(hyper_map); | |||||
| auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>(); | |||||
| if (hyper_map->GetFnLeaf() != nullptr) { | if (hyper_map->GetFnLeaf() != nullptr) { | ||||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; | oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; | ||||
| } | } | ||||
| } else if (meta_func_graph->isa<prim::HyperMap>()) { | } else if (meta_func_graph->isa<prim::HyperMap>()) { | ||||
| prim::HyperMapPtr hyper_map = meta_func_graph->cast<prim::HyperMapPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(hyper_map); | |||||
| auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>(); | |||||
| if (hyper_map->GetFnLeaf() != nullptr) { | if (hyper_map->GetFnLeaf() != nullptr) { | ||||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; | oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; | ||||
| } | } | ||||
| } else if (meta_func_graph->isa<prim::MapPy>()) {  // this statement must be before 'meta_func_graph->isa<prim::Map>()' |||||
| auto map = meta_func_graph->cast<prim::MapPyPtr>(); | |||||
| if (map->GetFnLeaf() != nullptr) { | |||||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}"; | |||||
| } | |||||
| } else if (meta_func_graph->isa<prim::Map>()) { | |||||
| auto map = meta_func_graph->cast<prim::MapPtr>(); | |||||
| if (map->GetFnLeaf() != nullptr) { | |||||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}"; | |||||
| } | |||||
| } else if (meta_func_graph->isa<prim::GradOperation>()) { | } else if (meta_func_graph->isa<prim::GradOperation>()) { | ||||
| prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>(); | prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>(); | ||||
| oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_ | oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_ | ||||
| @@ -0,0 +1,289 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "operator/composite/map.h" | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "ir/anf.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "pipeline/static_analysis/abstract_value.h" | |||||
| #include "pipeline/static_analysis/abstract_function.h" | |||||
| #include "pipeline/static_analysis/dshape.h" | |||||
| #include "pybind_api/api_register.h" | |||||
| #include "debug/trace.h" | |||||
| #include "operator/ops.h" | |||||
| #include "./common.h" | |||||
| namespace mindspore { | |||||
| // namespace to support composite operators definition | |||||
| namespace prim { | |||||
| using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; | |||||
| AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) { | |||||
| MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n"; | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| if (fn_arg != nullptr) { | |||||
| inputs.emplace_back(fn_arg); | |||||
| } else { | |||||
| inputs.emplace_back(NewValueNode(fn_leaf_)); | |||||
| } | |||||
| inputs.insert(inputs.end(), args.begin(), args.end()); | |||||
| return func_graph->NewCNode(inputs); | |||||
| } | |||||
| FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { | |||||
| // Generate func for leaf nodes | |||||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->debug_info()->set_name("map"); | |||||
| AnfNodePtr ptrFnArg = nullptr; | |||||
| if (fn_leaf_ == nullptr) { | |||||
| ptrFnArg = ptrGraph->add_parameter(); | |||||
| } | |||||
| AnfNodePtrList args; | |||||
| for (size_t i = 0; i < args_size; ++i) { | |||||
| args.emplace_back(ptrGraph->add_parameter()); | |||||
| } | |||||
| ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args)); | |||||
| return ptrGraph; | |||||
| } | |||||
| AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| std::size_t size = type->elements().size(); | |||||
| bool is_not_same = | |||||
| std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||||
| auto lhs = std::dynamic_pointer_cast<List>(item.second); | |||||
| MS_EXCEPTION_IF_NULL(lhs); | |||||
| return lhs->elements().size() != size; | |||||
| }); | |||||
| if (is_not_same) { | |||||
| MS_LOG(EXCEPTION) << "List in Map should have same length"; | |||||
| } | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimMakeList)); | |||||
| for (int i = 0; i < SizeToInt(size); ++i) { | |||||
| MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target"; | |||||
| auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); | |||||
| auto fn = NewValueNode(ptrGraph); | |||||
| std::vector<AnfNodePtr> inputs2; | |||||
| inputs2.push_back(fn); | |||||
| if (fn_arg != nullptr) { | |||||
| inputs2.push_back(fn_arg); | |||||
| } | |||||
| (void)std::transform( | |||||
| arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), | |||||
| [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) { | |||||
| return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); | |||||
| }); | |||||
| inputs.push_back(func_graph->NewCNode(inputs2)); | |||||
| } | |||||
| return func_graph->NewCNode(inputs); | |||||
| } | |||||
| AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| std::size_t size = type->elements().size(); | |||||
| bool is_not_same = | |||||
| std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||||
| auto lhs = std::dynamic_pointer_cast<Tuple>(item.second); | |||||
| MS_EXCEPTION_IF_NULL(lhs); | |||||
| return lhs->elements().size() != size; | |||||
| }); | |||||
| if (is_not_same) { | |||||
| MS_LOG(EXCEPTION) << "tuple in Map should have same length"; | |||||
| } | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||||
| for (int i = 0; i < SizeToInt(size); ++i) { | |||||
| MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; | |||||
| auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); | |||||
| auto fn = NewValueNode(ptrGraph); | |||||
| std::vector<AnfNodePtr> inputs2; | |||||
| inputs2.push_back(fn); | |||||
| if (fn_arg != nullptr) { | |||||
| inputs2.push_back(fn_arg); | |||||
| } | |||||
| (void)std::transform( | |||||
| arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), | |||||
| [&func_graph, &i](std::pair<AnfNodePtr, Any> item) { | |||||
| return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); | |||||
| }); | |||||
| inputs.push_back(func_graph->NewCNode(inputs2)); | |||||
| } | |||||
| return func_graph->NewCNode(inputs); | |||||
| } | |||||
| AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, | |||||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); | |||||
| inputs.push_back(NewValueNode(type)); | |||||
| std::size_t attrSize = type->GetAttributes().size(); | |||||
| for (std::size_t i = 0; i < attrSize; ++i) { | |||||
| MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; | |||||
| auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); | |||||
| auto fn = NewValueNode(ptrGraph); | |||||
| std::vector<AnfNodePtr> inputs2; | |||||
| inputs2.push_back(fn); | |||||
| if (fn_arg != nullptr) { | |||||
| inputs2.push_back(fn_arg); | |||||
| } | |||||
| int j = 0; | |||||
| for (auto item : arg_pairs) { | |||||
| inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); | |||||
| j++; | |||||
| } | |||||
| inputs.push_back(func_graph->NewCNode(inputs2)); | |||||
| } | |||||
| return func_graph->NewCNode(inputs); | |||||
| } | |||||
| AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||||
| bool found = false; | |||||
| TypeId id = kObjectTypeEnd; | |||||
| std::pair<AnfNodePtr, TypePtr> pair; | |||||
| for (auto &item : arg_pairs) { | |||||
| pair = item; | |||||
| MS_LOG(DEBUG) << "Map " << pair.second->ToString(); | |||||
| id = item.second->type_id(); | |||||
| if (nonleaf_.count(id)) { | |||||
| found = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (found) { | |||||
| // In the non-leaf case, all arguments must share the same generic type. |||||
| bool is_not_same = | |||||
| std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) { | |||||
| if (item.first != pair.first) { | |||||
| return item.second->type_id() != pair.second->type_id(); | |||||
| } | |||||
| return false; | |||||
| }); | |||||
| if (is_not_same) { | |||||
| std::ostringstream oss; | |||||
| oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n" | |||||
| << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; | |||||
| int idx = 0; | |||||
| for (auto &item : arg_pairs) { | |||||
| oss << ++idx << ": " << item.second->ToString() << "\n"; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n" | |||||
| << oss.str() << pair.second->ToString() << "\n"; | |||||
| } | |||||
| } | |||||
| switch (id) { | |||||
| case kObjectTypeList: { | |||||
| auto type = std::static_pointer_cast<List>(pair.second); | |||||
| return FullMakeList(type, func_graph, fn_arg, arg_pairs); | |||||
| } | |||||
| case kObjectTypeTuple: { | |||||
| auto type = std::static_pointer_cast<Tuple>(pair.second); | |||||
| return FullMakeTuple(type, func_graph, fn_arg, arg_pairs); | |||||
| } | |||||
| case kObjectTypeClass: { | |||||
| auto type = std::static_pointer_cast<Class>(pair.second); | |||||
| return FullMakeClass(type, func_graph, fn_arg, arg_pairs); | |||||
| } | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class " | |||||
| << ", but got " << pair.second->ToString(); | |||||
| } | |||||
| } | |||||
| FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||||
| ptrGraph->debug_info()->set_name("map"); | |||||
| AnfNodePtr ptrFnArg = nullptr; | |||||
| std::size_t i = 0; | |||||
| if (fn_leaf_ == nullptr) { | |||||
| ptrFnArg = ptrGraph->add_parameter(); | |||||
| i = 1; | |||||
| } | |||||
| ArgsPairList arg_pairs; | |||||
| std::size_t size = args_spec_list.size(); | |||||
| for (; i < size; ++i) { | |||||
| MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString(); | |||||
| arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); | |||||
| } | |||||
| ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs)); | |||||
| return ptrGraph; | |||||
| } | |||||
| abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||||
| if (fn_leaf_ == nullptr) { | |||||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||||
| // Assert that map's function param does not contain free variables | |||||
| if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) { | |||||
| auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]); | |||||
| auto func_graph = graph_func->func_graph(); | |||||
| if (func_graph->parent() != nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet."; | |||||
| } | |||||
| } | |||||
| } | |||||
| AbstractBasePtrList broadened; | |||||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), | |||||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| return arg->Broaden(); | |||||
| }); | |||||
| return broadened; | |||||
| } | |||||
| REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) { | |||||
| (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_") | |||||
| .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf")) | |||||
| .def(py::init<>()); | |||||
| })); | |||||
| } // namespace prim | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,98 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ | |||||
| #define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "ir/dtype.h" | |||||
| #include "ir/meta_func_graph.h" | |||||
| #include "operator/composite/multitype_funcgraph.h" | |||||
| namespace mindspore { | |||||
| // namespace to support composite operators definition | |||||
| namespace prim { | |||||
| using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; | |||||
| class Map : public MetaFuncGraph { | |||||
| public: | |||||
| explicit Map(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) | |||||
| : MetaFuncGraph("map"), | |||||
| fn_leaf_(fn_leaf), | |||||
| broadcast_(false), | |||||
| nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { | |||||
| Init(); | |||||
| } | |||||
| Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { | |||||
| Init(); | |||||
| } | |||||
| Map &operator=(const Map &h) { | |||||
| if (this != &h) { | |||||
| fn_leaf_ = h.fn_leaf_; | |||||
| broadcast_ = h.broadcast_; | |||||
| nonleaf_ = h.nonleaf_; | |||||
| if (fn_leaf_) { | |||||
| name_ = "map[" + fn_leaf_->name() + "]"; | |||||
| } | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| ~Map() override = default; | |||||
| MS_DECLARE_PARENT(Map, MetaFuncGraph) | |||||
| abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; | |||||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; | |||||
| MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } | |||||
| private: | |||||
| FuncGraphPtr GenerateLeafFunc(const size_t &args_size); | |||||
| AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); | |||||
| AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_pairs); | |||||
| AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_pairs); | |||||
| AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||||
| const ArgsPairList &arg_pairs); | |||||
| AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); | |||||
| void Init() { | |||||
| if (fn_leaf_ != nullptr) { | |||||
| name_ = "map[" + fn_leaf_->name() + "]"; | |||||
| } | |||||
| signatures_ = | |||||
| // def map(func:read, *args:ref): | |||||
| std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, | |||||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | |||||
| } | |||||
| MultitypeFuncGraphPtr fn_leaf_; | |||||
| bool broadcast_; | |||||
| std::set<TypeId> nonleaf_; | |||||
| }; | |||||
| using MapPtr = std::shared_ptr<Map>; | |||||
| class MapPy : public Map { | |||||
| public: | |||||
| explicit MapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : Map(fn_leaf) {} | |||||
| ~MapPy() override = default; | |||||
| MS_DECLARE_PARENT(MapPy, Map) | |||||
| }; | |||||
| using MapPyPtr = std::shared_ptr<MapPy>; | |||||
| } // namespace prim | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ | |||||
| @@ -14,9 +14,14 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <string> | |||||
| #include <sstream> | |||||
| #include "ir/dtype.h" | |||||
| #include "common/utils.h" | |||||
| #include "operator/ops.h" | |||||
| #include "pipeline/static_analysis/param_validator.h" | #include "pipeline/static_analysis/param_validator.h" | ||||
| #include "pipeline/static_analysis/prim.h" | #include "pipeline/static_analysis/prim.h" | ||||
| #include "operator/ops.h" | |||||
| #include "pipeline/static_analysis/utils.h" | #include "pipeline/static_analysis/utils.h" | ||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| @@ -50,6 +55,65 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit | |||||
| return AbstractFunction::MakeAbstractFunction(jv); | return AbstractFunction::MakeAbstractFunction(jv); | ||||
| } | } | ||||
| class UndeterminedShapeType { | |||||
| public: | |||||
| explicit UndeterminedShapeType(const std::string &env_str) { | |||||
| // param_name indices_shape indices_type values_shape values_type dense_shape | |||||
| // export UNDETERMINED_SPARSE_SHAPE_TYPES="w1:2:Int32:2 1 2:Float32:3 1 2" | |||||
| std::vector<string> fields; | |||||
| string tmp; | |||||
| std::stringstream input(env_str); | |||||
| while (std::getline(input, tmp, ':')) { | |||||
| fields.push_back(tmp); | |||||
| } | |||||
| if (fields.size() != fields_num) { | |||||
| MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size(); | |||||
| } | |||||
| param_name_ = fields[0]; | |||||
| indices_shape_ = GetShape(fields[1]); | |||||
| indices_type_ = StringToType(fields[2]); | |||||
| values_shape_ = GetShape(fields[3]); | |||||
| values_type_ = StringToType(fields[4]); | |||||
| auto dense_shape_vec = GetShape(fields[5]); | |||||
| AbstractBasePtrList dense_shape_list; | |||||
| (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list), | |||||
| [](const auto &elem) { return FromValue(elem, false); }); | |||||
| dense_shape_ = dense_shape_list; | |||||
| } | |||||
| const std::string ¶m_name() { return param_name_; } | |||||
| const std::vector<int> &indices_shape() { return indices_shape_; } | |||||
| const TypePtr &indices_type() { return indices_type_; } | |||||
| const std::vector<int> &values_shape() { return values_shape_; } | |||||
| const TypePtr &values_type() { return values_type_; } | |||||
| const AbstractBasePtrList &dense_shape() { return dense_shape_; } | |||||
| private: | |||||
| std::string param_name_; | |||||
| std::vector<int> indices_shape_; | |||||
| TypePtr indices_type_; | |||||
| std::vector<int> values_shape_; | |||||
| TypePtr values_type_; | |||||
| AbstractBasePtrList dense_shape_; | |||||
| static const size_t fields_num; | |||||
| std::vector<int> GetShape(const std::string &shape_str); | |||||
| }; | |||||
| std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) { | |||||
| std::vector<int> ret; | |||||
| std::istringstream iss(shape_str); | |||||
| int elem; | |||||
| while (iss >> elem) { |||||
| ret.emplace_back(elem); |||||
| } |||||
| return ret; | |||||
| } | |||||
| const size_t UndeterminedShapeType::fields_num = 6; | |||||
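For reference, the UNDETERMINED_SPARSE_SHAPE_TYPES value is six colon-separated fields (param name, indices shape, indices type, values shape, values type, dense shape), with each shape written as space-separated integers. A minimal Python sketch of the same parsing, using the default string from this patch (the helper name is illustrative only):

    def parse_sparse_shape_types(env_str="w1:2:Int32:2 1 2:Float32:3 1 2"):
        # Six colon-separated fields, mirroring UndeterminedShapeType's constructor.
        name, idx_shape, idx_type, val_shape, val_type, dense = env_str.split(":")

        def to_shape(text):
            return [int(x) for x in text.split()]

        return {
            "param_name": name,                     # "w1"
            "indices_shape": to_shape(idx_shape),   # [2]
            "indices_type": idx_type,               # "Int32"
            "values_shape": to_shape(val_shape),    # [2, 1, 2]
            "values_type": val_type,                # "Float32"
            "dense_shape": to_shape(dense),         # [3, 1, 2]
        }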
| AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| @@ -62,6 +126,31 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt | |||||
| if (type->type_id() != kObjectTypeSymbolicKeyType) { | if (type->type_id() != kObjectTypeSymbolicKeyType) { | ||||
| MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); | MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); | ||||
| } | } | ||||
| if (key->sparse_grad()) { | |||||
| // Will be fixed once undetermined type ready | |||||
| auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); | |||||
| if (sparse_shape_types.empty()) { | |||||
| sparse_shape_types = "w1:2:Int32:2 1 2:Float32:3 1 2"; | |||||
| } | |||||
| MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString() << ", Undetermined shape is " | |||||
| << sparse_shape_types; | |||||
| auto shape_types = UndeterminedShapeType(sparse_shape_types); | |||||
| AbstractBasePtrList sparse_list; | |||||
| // indices | |||||
| auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.indices_type()); | |||||
| auto indices = std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types.indices_shape())); | |||||
| sparse_list.emplace_back(indices); | |||||
| // values | |||||
| auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.values_type()); | |||||
| auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types.values_shape())); | |||||
| sparse_list.emplace_back(dout); | |||||
| // dense_shape | |||||
| sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types.dense_shape())); | |||||
| return std::make_shared<AbstractTuple>(sparse_list); | |||||
| } | |||||
| if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) { | if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) { | ||||
| return dflt; | return dflt; | ||||
| } | } | ||||
| @@ -80,8 +169,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt | |||||
| CheckArgsSize(primitive->name(), args_spec_list, 3); | CheckArgsSize(primitive->name(), args_spec_list, 3); | ||||
| auto key = args_spec_list[1]; | auto key = args_spec_list[1]; | ||||
| auto value = args_spec_list[2]; | |||||
| ValuePtr key_value_ptr = key->GetValueTrack(); | ValuePtr key_value_ptr = key->GetValueTrack(); | ||||
| MS_EXCEPTION_IF_NULL(key_value_ptr); | MS_EXCEPTION_IF_NULL(key_value_ptr); | ||||
| auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>(); | auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>(); | ||||
| @@ -91,7 +178,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt | |||||
| } | } | ||||
| auto expected = key_value_track->abstract(); | auto expected = key_value_track->abstract(); | ||||
| MS_EXCEPTION_IF_NULL(expected); | MS_EXCEPTION_IF_NULL(expected); | ||||
| (void)expected->Join(value); | |||||
| return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | ||||
| } | } | ||||
| @@ -126,7 +212,9 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & | |||||
| if (type->type_id() != kObjectTypeRefKey) { | if (type->type_id() != kObjectTypeRefKey) { | ||||
| MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | ||||
| } | } | ||||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||||
| auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||||
| ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); | |||||
| return ret; | |||||
| } | } | ||||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | ||||
| @@ -38,6 +38,7 @@ | |||||
| #include "pipeline/remove_value_node_dup.h" | #include "pipeline/remove_value_node_dup.h" | ||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| #include "vm/transform.h" | #include "vm/transform.h" | ||||
| #include "parse/python_adapter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pipeline { | namespace pipeline { | ||||
| @@ -228,6 +229,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||||
| if (param_node->has_default()) { | if (param_node->has_default()) { | ||||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param()); | auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param()); | ||||
| AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true); | AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true); | ||||
| auto sparse_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad")); | |||||
| ptr->set_sparse_grad(sparse_grad); | |||||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | ||||
| args_spec.push_back(ptr); | args_spec.push_back(ptr); | ||||
| @@ -51,6 +51,7 @@ ValuePtr AbstractBase::BuildValue() const { | |||||
| AbstractBasePtr AbstractBase::Broaden() const { | AbstractBasePtr AbstractBase::Broaden() const { | ||||
| AbstractBasePtr clone = Clone(); | AbstractBasePtr clone = Clone(); | ||||
| clone->set_value(kAnyValue); | clone->set_value(kAnyValue); | ||||
| clone->set_sparse_grad(sparse_grad_); | |||||
| return clone; | return clone; | ||||
| } | } | ||||
| @@ -63,7 +64,8 @@ std::string AbstractBase::ToString() const { | |||||
| MS_EXCEPTION_IF_NULL(type_); | MS_EXCEPTION_IF_NULL(type_); | ||||
| MS_EXCEPTION_IF_NULL(shape_); | MS_EXCEPTION_IF_NULL(shape_); | ||||
| buffer << type_name() << "(" | buffer << type_name() << "(" | ||||
| << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")"; | |||||
| << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() | |||||
| << " sparse_grad: " << sparse_grad_ << ")"; | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| @@ -72,16 +74,22 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden() | |||||
| AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | ||||
| MS_EXCEPTION_IF_NULL(other); | MS_EXCEPTION_IF_NULL(other); | ||||
| if (*this == *other) { | if (*this == *other) { | ||||
| return shared_from_base<AbstractBase>(); | |||||
| auto ret = shared_from_base<AbstractBase>(); | |||||
| ret->set_sparse_grad(sparse_grad()); | |||||
| return ret; | |||||
| } | } | ||||
| auto value_self = GetValueTrack(); | auto value_self = GetValueTrack(); | ||||
| MS_EXCEPTION_IF_NULL(value_self); | MS_EXCEPTION_IF_NULL(value_self); | ||||
| ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | ||||
| TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); | TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); | ||||
| if (res_value == value_self) { | if (res_value == value_self) { | ||||
| return shared_from_base<AbstractBase>(); | |||||
| auto ret = shared_from_base<AbstractBase>(); | |||||
| ret->set_sparse_grad(sparse_grad()); | |||||
| return ret; | |||||
| } | } | ||||
| return std::make_shared<AbstractScalar>(res_value, res_type); | |||||
| auto ret = std::make_shared<AbstractScalar>(res_value, res_type); | |||||
| ret->set_sparse_grad(sparse_grad()); | |||||
| return ret; | |||||
| } | } | ||||
| AbstractBasePtr AbstractType::Clone() const { | AbstractBasePtr AbstractType::Clone() const { | ||||
| @@ -423,7 +431,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||||
| } | } | ||||
| auto element = element_->Join(other_tensor->element_); | auto element = element_->Join(other_tensor->element_); | ||||
| auto shape = ShapeJoin(this->shape(), other_tensor->shape()); | auto shape = ShapeJoin(this->shape(), other_tensor->shape()); | ||||
| return std::make_shared<AbstractTensor>(element, shape); | |||||
| auto ret = std::make_shared<AbstractTensor>(element, shape); | |||||
| ret->set_sparse_grad(sparse_grad()); | |||||
| return ret; | |||||
| } | } | ||||
| bool AbstractTensor::operator==(const AbstractTensor &other) const { | bool AbstractTensor::operator==(const AbstractTensor &other) const { | ||||
| @@ -463,6 +473,7 @@ AbstractBasePtr AbstractTensor::Clone() const { | |||||
| ShapePtr shp = shape(); | ShapePtr shp = shape(); | ||||
| clone->set_shape(shp->Clone()); | clone->set_shape(shp->Clone()); | ||||
| clone->set_value(GetValueTrack()); | clone->set_value(GetValueTrack()); | ||||
| clone->set_sparse_grad(sparse_grad()); | |||||
| return clone; | return clone; | ||||
| } | } | ||||
| @@ -472,6 +483,7 @@ AbstractBasePtr AbstractTensor::Broaden() const { | |||||
| auto shp = shape(); | auto shp = shape(); | ||||
| broaden->set_shape(shp->Clone()); | broaden->set_shape(shp->Clone()); | ||||
| broaden->set_value(kAnyValue); | broaden->set_value(kAnyValue); | ||||
| broaden->set_sparse_grad(sparse_grad()); | |||||
| return broaden; | return broaden; | ||||
| } | } | ||||
| @@ -482,6 +494,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const { | |||||
| shp->Broaden(); | shp->Broaden(); | ||||
| broaden->set_shape(shp); | broaden->set_shape(shp); | ||||
| broaden->set_value(kAnyValue); | broaden->set_value(kAnyValue); | ||||
| broaden->set_sparse_grad(sparse_grad()); | |||||
| return broaden; | return broaden; | ||||
| } | } | ||||
| @@ -502,7 +515,8 @@ std::string AbstractTensor::ToString() const { | |||||
| MS_EXCEPTION_IF_NULL(value_track); | MS_EXCEPTION_IF_NULL(value_track); | ||||
| buffer << type_name() << "(" | buffer << type_name() << "(" | ||||
| << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() | << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() | ||||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"; | |||||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() | |||||
| << ")"; | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| @@ -44,7 +44,7 @@ class AbstractBase : public Base { | |||||
| public: | public: | ||||
| explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, | explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, | ||||
| const BaseShapePtr &shape = kNoShape) | const BaseShapePtr &shape = kNoShape) | ||||
| : value_(value), type_(type), shape_(shape) {} | |||||
| : value_(value), type_(type), shape_(shape), sparse_grad_(false) {} | |||||
| ~AbstractBase() override = default; | ~AbstractBase() override = default; | ||||
| MS_DECLARE_PARENT(AbstractBase, Base) | MS_DECLARE_PARENT(AbstractBase, Base) | ||||
| @@ -53,11 +53,13 @@ class AbstractBase : public Base { | |||||
| virtual bool operator==(const AbstractBase &other) const; | virtual bool operator==(const AbstractBase &other) const; | ||||
| void set_value(const ValuePtr &value) { value_ = value; } | void set_value(const ValuePtr &value) { value_ = value; } | ||||
| void set_sparse_grad(const bool &sparse_grad) { sparse_grad_ = sparse_grad; } | |||||
| void set_type(const TypePtr &type) { type_ = type; } | void set_type(const TypePtr &type) { type_ = type; } | ||||
| void set_shape(const BaseShapePtr &shape) { shape_ = shape; } | void set_shape(const BaseShapePtr &shape) { shape_ = shape; } | ||||
| void set_value_desc(const std::string &desc) { value_desc_ = desc; } | void set_value_desc(const std::string &desc) { value_desc_ = desc; } | ||||
| const std::string &value_desc() const { return value_desc_; } | const std::string &value_desc() const { return value_desc_; } | ||||
| ValuePtr GetValueTrack() const { return value_; } | ValuePtr GetValueTrack() const { return value_; } | ||||
| bool sparse_grad() const { return sparse_grad_; } | |||||
| TypePtr GetTypeTrack() const { return type_; } | TypePtr GetTypeTrack() const { return type_; } | ||||
| BaseShapePtr GetShapeTrack() const { return shape_; } | BaseShapePtr GetShapeTrack() const { return shape_; } | ||||
| @@ -85,6 +87,7 @@ class AbstractBase : public Base { | |||||
| TypePtr type_; | TypePtr type_; | ||||
| BaseShapePtr shape_; | BaseShapePtr shape_; | ||||
| std::string value_desc_; // store initial value description for error report | std::string value_desc_; // store initial value description for error report | ||||
| bool sparse_grad_; | |||||
| }; | }; | ||||
| class AbstractScalar : public AbstractBase { | class AbstractScalar : public AbstractBase { | ||||
| @@ -851,7 +851,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| } | } | ||||
| auto refkey = key_value->cast<RefKeyPtr>(); | auto refkey = key_value->cast<RefKeyPtr>(); | ||||
| if (refkey == nullptr) { | if (refkey == nullptr) { | ||||
| return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>()); | |||||
| auto ret = std::make_shared<AbstractScalar>(type); | |||||
| auto ref_value = ref_abs->ref(); | |||||
| MS_EXCEPTION_IF_NULL(ref_value); | |||||
| ret->set_sparse_grad(ref_value->sparse_grad()); | |||||
| return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | |||||
| } | } | ||||
| std::string name = refkey->tag(); | std::string name = refkey->tag(); | ||||
| @@ -865,6 +869,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| x = SensitivityTransform(x); | x = SensitivityTransform(x); | ||||
| std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | ||||
| std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | ||||
| abs_scalar->set_sparse_grad(x->sparse_grad()); | |||||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -50,12 +50,14 @@ class Parameter: | |||||
| requires_grad (bool): True if the parameter requires gradient. Default: True. | requires_grad (bool): True if the parameter requires gradient. Default: True. | ||||
| layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode, | layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode, | ||||
| broadcast and gradients communication would not be applied on parameters. Default: False. | broadcast and gradients communication would not be applied on parameters. Default: False. | ||||
| sparse_grad (bool): True if the parameter's gradient is sparse. Default: False. | |||||
| """ | """ | ||||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): | |||||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=False): | |||||
| self.set_parameter_data(default_input) | self.set_parameter_data(default_input) | ||||
| self.name = name | self.name = name | ||||
| self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
| self.layerwise_parallel = layerwise_parallel | self.layerwise_parallel = layerwise_parallel | ||||
| self.sparse_grad = sparse_grad | |||||
| self._is_init = False | self._is_init = False | ||||
| self._sliced = False | self._sliced = False | ||||
| self.clone_info = _CloneInfo() | self.clone_info = _CloneInfo() | ||||
| @@ -168,6 +170,17 @@ class Parameter: | |||||
| raise TypeError("`requires_grad` parameter must be bool type") | raise TypeError("`requires_grad` parameter must be bool type") | ||||
| self._requires_grad = value | self._requires_grad = value | ||||
| @property | |||||
| def sparse_grad(self): | |||||
| """Return whether the parameter's gradient is sparse.""" | |||||
| return self._sparse_grad | |||||
| @sparse_grad.setter | |||||
| def sparse_grad(self, value=True): | |||||
| if not isinstance(value, bool): | |||||
| raise TypeError("`sparse_grad` parameter must be bool type") | |||||
| self._sparse_grad = value | |||||
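Declaring a parameter whose gradient should be treated as sparse would then look like the following (mirroring the test added at the end of this change set):

    w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)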
| @property | @property | ||||
| def data(self): | def data(self): | ||||
| return self.default_input | return self.default_input | ||||
| @@ -30,6 +30,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum() | |||||
| transpose = P.Transpose() | transpose = P.Transpose() | ||||
| shape_op = P.Shape() | shape_op = P.Shape() | ||||
| reshape = P.Reshape() | reshape = P.Reshape() | ||||
| size_op = P.Size() | |||||
| invert_permutation = P.InvertPermutation() | invert_permutation = P.InvertPermutation() | ||||
| logical_and = P.LogicalAnd() | logical_and = P.LogicalAnd() | ||||
| @@ -284,6 +285,37 @@ def get_bprop_gather_v2(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.SparseGatherV2) | |||||
| def get_bprop_sparse_gather_v2(self): | |||||
| """Generate bprop for SparseGatherV2""" | |||||
| def bprop(x, indices, axis, out, dout): | |||||
| x_shp = shape_op(x) | |||||
| if axis == 0: | |||||
| indices_size = (size_op(indices),) | |||||
| x_tail_shp = x_shp[1:] | |||||
| values_shape = indices_size + x_tail_shp | |||||
| values = reshape(dout, values_shape) | |||||
| indices = reshape(indices, indices_size) | |||||
| return (indices, values, x_shp), zeros_like(indices), zeros_like(axis) | |||||
| if F.rank(dout) == 0: | |||||
| dout = P.ExpandDims()(dout, -1) | |||||
| if F.rank(indices) == 0: | |||||
| indices = P.ExpandDims()(indices, -1) | |||||
| out_shp = shape_op(dout) | |||||
| ind_shp = shape_op(indices) | |||||
| # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) | |||||
| perm_1 = _generate_shape_index(out_shp, ind_shp, axis) | |||||
| values_transpose = transpose(dout, perm_1) | |||||
| params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) | |||||
| # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) | |||||
| perm_2 = _generate_inverse_index(x_shp, axis) | |||||
| params_grad = transpose(params_grad, perm_2) | |||||
| return params_grad, zeros_like(indices), zeros_like(axis) | |||||
| return bprop | |||||
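Note that for axis 0 the bprop above returns the parameter gradient as an (indices, values, dense_shape) triple instead of a dense tensor. A rough shape walk-through, assuming the inputs used by the test later in this change set (w1 of shape (3, 1, 2) gathered with indices [0, 1]):

    # Assumed shapes, for illustration only.
    x_shp = (3, 1, 2)          # shape_op(x)
    indices_size = (2,)        # (size_op(indices),)
    values_shape = (2, 1, 2)   # indices_size + x_shp[1:]
    # values = reshape(dout, values_shape); the gradient w.r.t. x is then the triple
    # (indices, values, x_shp): gathered row indices, their gradients, and the dense shape.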
| @bprop_getters.register(P.Range) | @bprop_getters.register(P.Range) | ||||
| def get_bprop_range(self): | def get_bprop_range(self): | ||||
| """Generate bprop for Range""" | """Generate bprop for Range""" | ||||
| @@ -20,7 +20,7 @@ Pre-defined combination of operators. | |||||
| """ | """ | ||||
| from .base import GradOperation, HyperMap, MultitypeFuncGraph, add_flags, \ | |||||
| from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \ | |||||
| grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \ | grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \ | ||||
| core, env_get, tail, zip_operation | core, env_get, tail, zip_operation | ||||
| from .clip_ops import clip_by_value | from .clip_ops import clip_by_value | ||||
| @@ -19,7 +19,7 @@ | |||||
| from functools import partial | from functools import partial | ||||
| from mindspore import context | from mindspore import context | ||||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ | |||||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ | |||||
| TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ | TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.api import ms_function, _pynative_exec | from ...common.api import ms_function, _pynative_exec | ||||
| @@ -240,6 +240,69 @@ class HyperMap(HyperMap_): | |||||
| return func(*args_list) | return func(*args_list) | ||||
| return tuple(map(hypermap, *args_list)) | return tuple(map(hypermap, *args_list)) | ||||
| class Map(Map_): | |||||
| """ | |||||
| Map applies the given operation to each element of the input sequences. |||||
| Args: |||||
| ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`, |||||
| the operation should be passed as the first input when the instance is called. |||||
| Inputs: |||||
| - **args** (Tuple[sequence]) - If `ops` is not `None`, all inputs should be sequences of the same length, |||||
| and the i-th elements of the sequences form one call of the operation, e.g. if the length of args is 2, |||||
| `(args[0][i], args[1][i])` is the input of the operation for each `i`. |||||
| If `ops` is `None`, the first input is the operation and the remaining inputs are the sequences. |||||
| Outputs: |||||
| Sequence, with the same type and length as the input sequences; each element is the result of applying |||||
| the operation to the corresponding elements, e.g. `operation(args[0][i], args[1][i])`. |||||
| """ | |||||
| def __init__(self, ops=None): | |||||
| self.ops = ops | |||||
| if ops: | |||||
| Map_.__init__(self, ops) | |||||
| else: | |||||
| Map_.__init__(self) | |||||
| def __call__(self, *args): | |||||
| func = args[0] | |||||
| count = 0 | |||||
| count_max = 1 | |||||
| args_list = args[1:] | |||||
| if self.ops is not None: | |||||
| func = self.ops | |||||
| args_list = args | |||||
| for item in args_list: | |||||
| if isinstance(item, (tuple, list)): | |||||
| count_max = len(item) | |||||
| break | |||||
| def get_item(x): | |||||
| nonlocal count | |||||
| if isinstance(x, (tuple, list)): | |||||
| return x[count] | |||||
| return x | |||||
| for i in range(count_max): | |||||
| true_args = tuple(map(get_item, args_list)) | |||||
| func(*true_args) | |||||
| count = i + 1 | |||||
| return True | |||||
| def register(self, *type_names): | |||||
| """Register a function for the given type string.""" | |||||
| def deco(fn): | |||||
| self.register_fn(type_names, fn) | |||||
| return fn | |||||
| return deco | |||||
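A rough usage sketch of the new composite (the `square` function is hypothetical; the optimizer below uses the same pattern with `F.partial` and per-parameter tuples):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import composite as C
    from mindspore.ops import operations as P

    square = C.MultitypeFuncGraph("square")

    @square.register("Tensor")
    def _square_tensor(x):
        return P.Square()(x)

    common_map = C.Map()
    x1 = Tensor(np.array([1.0, 2.0]).astype(np.float32))
    x2 = Tensor(np.array([3.0, 4.0]).astype(np.float32))
    # In graph mode this yields roughly (square(x1), square(x2)), built via make_tuple in
    # FullMakeTuple above; the pynative __call__ simply runs the function over each element.
    result = common_map(square, (x1, x2))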
| class _ListAppend(ListAppend_): | class _ListAppend(ListAppend_): | ||||
| """ | """ | ||||
| A metafuncgraph class that append one element to list. | A metafuncgraph class that append one element to list. | ||||
| @@ -21,7 +21,7 @@ A collection of operators to build neural networks or computing functions. |||||
| from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | ||||
| Diag, DiagPart, DType, ExpandDims, Eye, | Diag, DiagPart, DType, ExpandDims, Eye, | ||||
| Fill, GatherNd, GatherV2, InvertPermutation, | |||||
| Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, | |||||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | ||||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, | Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, | ||||
| SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | ||||
| @@ -122,6 +122,7 @@ __all__ = [ | |||||
| 'Transpose', | 'Transpose', | ||||
| 'OneHot', | 'OneHot', | ||||
| 'GatherV2', | 'GatherV2', | ||||
| 'SparseGatherV2', | |||||
| 'Concat', | 'Concat', | ||||
| 'Pack', | 'Pack', | ||||
| 'Unpack', | 'Unpack', | ||||
| @@ -526,6 +526,29 @@ class GatherV2(PrimitiveWithInfer): | |||||
| return out | return out | ||||
| class SparseGatherV2(GatherV2): | |||||
| """ | |||||
| Returns a slice of input tensor based on the specified indices and axis. | |||||
| Inputs: | |||||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||||
| The original Tensor. | |||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||||
| Specifies the indices of elements of the original Tensor. Must be in the range | |||||
| `[0, input_param.shape()[axis])`. | |||||
| - **axis** (int) - Specifies the dimension index to gather indices. | |||||
| Outputs: | |||||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||||
| Examples: | |||||
| >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32) | |||||
| >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32) | |||||
| >>> axis = 1 | |||||
| >>> out = P.GatherV2()(input_params, input_indices, axis) | |||||
| """ | |||||
| class Range(PrimitiveWithInfer): | class Range(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Creates a sequence of numbers. | Creates a sequence of numbers. | ||||
| @@ -332,6 +332,8 @@ class CheckBprop(PrimitiveWithInfer): | |||||
| def infer_shape(self, xshapes, yshapes): | def infer_shape(self, xshapes, yshapes): | ||||
| tips = f'Bprop of {self.prim_to_check}' | tips = f'Bprop of {self.prim_to_check}' | ||||
| validator.check_value_type('grads', xshapes, (tuple,), tips) | |||||
| validator.check_value_type('params', yshapes, (tuple,), tips) | |||||
| if len(xshapes) < len(yshapes): | if len(xshapes) < len(yshapes): | ||||
| raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," | raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," | ||||
| f" but got {len(xshapes)}.") | f" but got {len(xshapes)}.") | ||||
| @@ -348,6 +350,8 @@ class CheckBprop(PrimitiveWithInfer): | |||||
| def infer_dtype(self, xdtypes, ydtypes): | def infer_dtype(self, xdtypes, ydtypes): | ||||
| tips = f'Bprop of {self.prim_to_check}' | tips = f'Bprop of {self.prim_to_check}' | ||||
| validator.check_value_type('grads', xdtypes, (tuple,), tips) | |||||
| validator.check_value_type('params', ydtypes, (tuple,), tips) | |||||
| if len(xdtypes) < len(ydtypes): | if len(xdtypes) < len(ydtypes): | ||||
| raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," | raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," | ||||
| f" but got {len(xdtypes)}.") | f" but got {len(xdtypes)}.") | ||||
| @@ -0,0 +1,173 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test adam """ | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor, Parameter, context | |||||
| from mindspore.common.api import _executor | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from mindspore.nn.optim import Optimizer | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | |||||
| adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") | |||||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||||
| "Tensor", "Tensor", "Tensor", "Bool") | |||||
| def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): | |||||
| op_mul = P.Mul() | |||||
| op_square = P.Square() | |||||
| op_sqrt = P.Sqrt() | |||||
| op_cast = P.Cast() | |||||
| op_reshape = P.Reshape() | |||||
| op_shape = P.Shape() | |||||
| param_fp32 = op_cast(param, mstype.float32) | |||||
| m_fp32 = op_cast(m, mstype.float32) | |||||
| v_fp32 = op_cast(v, mstype.float32) | |||||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||||
| next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) | |||||
| next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) | |||||
| - beta2, op_square(gradient_fp32)) | |||||
| update = next_m / (op_sqrt(next_v) + eps) | |||||
| if decay_flag: | |||||
| update = update + op_mul(weight_decay_tensor, param_fp32) | |||||
| update_with_lr = op_mul(lr, update) | |||||
| next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) | |||||
| next_v = F.depend(next_v, F.assign(param, next_param)) | |||||
| next_v = F.depend(next_v, F.assign(m, next_m)) | |||||
| next_v = F.depend(next_v, F.assign(v, next_v)) | |||||
| return next_v | |||||
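For reference, the per-parameter update implemented above (note: no bias correction) is, in formula form:

    m_new  = beta1 * m + (1 - beta1) * grad
    v_new  = beta2 * v + (1 - beta2) * grad**2
    update = m_new / (sqrt(v_new) + eps)      # plus weight_decay * param when decay_flag is set
    param  = param - lr * update

and the function returns v_new after the assignments, chained through F.depend so the parameter and moment updates stay in the graph.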
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||||
| "Tensor", "Tensor", "Tuple", "Bool") | |||||
| def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): | |||||
| return gradient[2][2] | |||||
| def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||||
| """Check the type of inputs.""" | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||||
| validator.check_value_type("eps", eps, [float], prim_name) | |||||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | |||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| class AdamWeightDecaySparse(Optimizer): | |||||
| """ | |||||
| Implements the Adam algorithm with the weight decay fix. |||||
| Args: | |||||
| params (list[Parameter]): A list of parameter, which will be updated. The element in `params` | |||||
| should be class mindspore.Parameter. | |||||
| learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is | |||||
| Iterable or a Tensor and the dims of the Tensor is 1, | |||||
| use dynamic learning rate, then the i-th step will | |||||
| take the i-th value as the learning rate. | |||||
| When the learning_rate is float or learning_rate is a Tensor | |||||
| but the dims of the Tensor is 0, use fixed learning rate. | |||||
| Other cases are not supported. Default: 1e-3. | |||||
| beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. | |||||
| Should be in range (0.0, 1.0). | |||||
| beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. | |||||
| Should be in range (0.0, 1.0). | |||||
| eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. | |||||
| Should be greater than 0. | |||||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||||
| lambda x: 'beta' not in x.name and 'gamma' not in x.name. |||||
| Inputs: | |||||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`, | |||||
| and might be in sparse format. | |||||
| Outputs: | |||||
| tuple[Parameter], the updated velocity value, the shape is the same as `params`. | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) | |||||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||||
| """ | |||||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, | |||||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||||
| super(AdamWeightDecaySparse, self).__init__(learning_rate, params) | |||||
| if self.is_group: | |||||
| raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") | |||||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||||
| self.eps = Tensor(np.array([eps]).astype(np.float32)) | |||||
| self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) | |||||
| self.params = self.parameters | |||||
| self.moments1 = self.params.clone(prefix="adam_m", init='zeros') | |||||
| self.moments2 = self.params.clone(prefix="adam_v", init='zeros') | |||||
| self.decay_flag = tuple(decay_filter(x) for x in self.params) | |||||
| self.map = C.Map() | |||||
| def construct(self, gradients): | |||||
| lr = self.get_lr() | |||||
| updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, | |||||
| self.weight_decay_tensor), | |||||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||||
| return updated_velocity | |||||
| def test_AdamWeightDecaySparse(): | |||||
| """ test_AdamWeightDecaySparse """ | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| class Loss(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Loss, self).__init__() | |||||
| def construct(self, base, target): | |||||
| return base | |||||
| class NetWithSparseGatherV2(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetWithSparseGatherV2, self).__init__() | |||||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True) | |||||
| self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") | |||||
| self.gatherv2 = P.SparseGatherV2() | |||||
| self.axis = 0 | |||||
| def construct(self, indices): | |||||
| return self.gatherv2(self.w1, indices, self.axis) * self.w2 | |||||
| inputs = Tensor(np.array([0, 1]).astype(np.int32)) | |||||
| label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) | |||||
| net = NetWithSparseGatherV2() | |||||
| net.set_train() | |||||
| loss = Loss() | |||||
| optimizer = AdamWeightDecaySparse(net.trainable_params()) | |||||
| net_with_loss = WithLossCell(net, loss) | |||||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||||
| _executor.compile(train_network, inputs, label) | |||||