| @@ -30,6 +30,7 @@ | |||
| #include "pipeline/parse/python_adapter.h" | |||
| #include "pipeline/parse/resolve.h" | |||
| #include "operator/composite/composite.h" | |||
| #include "operator/composite/map.h" | |||
| #include "utils/ordered_map.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "utils/utils.h" | |||
| @@ -190,6 +191,8 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap | |||
| * ├── MultitypeGraph | |||
| * ├── HyperMap | |||
| * │ └── HyperMapPy | |||
| * ├── Map | |||
| * │ └── MapPy | |||
| * ├── Tail | |||
| * ├── MakeTupleGradient | |||
| * ├── GradOperation | |||
| @@ -208,17 +211,25 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_ | |||
| oss << GetMultitypeFuncGraphText(mt_func_graph); | |||
| } else if (meta_func_graph | |||
| ->isa<prim::HyperMapPy>()) { // this statement must come before 'meta_graph->isa<prim::HyperMap>()' | |||
| prim::HyperMapPyPtr hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>(); | |||
| MS_EXCEPTION_IF_NULL(hyper_map); | |||
| auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>(); | |||
| if (hyper_map->GetFnLeaf() != nullptr) { | |||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; | |||
| } | |||
| } else if (meta_func_graph->isa<prim::HyperMap>()) { | |||
| prim::HyperMapPtr hyper_map = meta_func_graph->cast<prim::HyperMapPtr>(); | |||
| MS_EXCEPTION_IF_NULL(hyper_map); | |||
| auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>(); | |||
| if (hyper_map->GetFnLeaf() != nullptr) { | |||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; | |||
| } | |||
| } else if (meta_func_graph->isa<prim::MapPy>()) { // this statement must come before 'meta_graph->isa<prim::Map>()' | |||
| auto map = meta_func_graph->cast<prim::MapPyPtr>(); | |||
| if (map->GetFnLeaf() != nullptr) { | |||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}"; | |||
| } | |||
| } else if (meta_func_graph->isa<prim::Map>()) { | |||
| auto map = meta_func_graph->cast<prim::MapPtr>(); | |||
| if (map->GetFnLeaf() != nullptr) { | |||
| oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}"; | |||
| } | |||
| } else if (meta_func_graph->isa<prim::GradOperation>()) { | |||
| prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>(); | |||
| oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_ | |||
| @@ -0,0 +1,289 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "operator/composite/map.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pipeline/static_analysis/abstract_function.h" | |||
| #include "pipeline/static_analysis/dshape.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "debug/trace.h" | |||
| #include "operator/ops.h" | |||
| #include "./common.h" | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| namespace prim { | |||
| using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; | |||
| AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) { | |||
| MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n"; | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> inputs; | |||
| if (fn_arg != nullptr) { | |||
| inputs.emplace_back(fn_arg); | |||
| } else { | |||
| inputs.emplace_back(NewValueNode(fn_leaf_)); | |||
| } | |||
| inputs.insert(inputs.end(), args.begin(), args.end()); | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { | |||
| // Generate func for leaf nodes | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->debug_info()->set_name("map"); | |||
| AnfNodePtr ptrFnArg = nullptr; | |||
| if (fn_leaf_ == nullptr) { | |||
| ptrFnArg = ptrGraph->add_parameter(); | |||
| } | |||
| AnfNodePtrList args; | |||
| for (size_t i = 0; i < args_size; ++i) { | |||
| args.emplace_back(ptrGraph->add_parameter()); | |||
| } | |||
| ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args)); | |||
| return ptrGraph; | |||
| } | |||
| AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::size_t size = type->elements().size(); | |||
| bool is_not_same = | |||
| std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||
| auto lhs = std::dynamic_pointer_cast<List>(item.second); | |||
| MS_EXCEPTION_IF_NULL(lhs); | |||
| return lhs->elements().size() != size; | |||
| }); | |||
| if (is_not_same) { | |||
| MS_LOG(EXCEPTION) << "List in Map should have same length"; | |||
| } | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimMakeList)); | |||
| for (int i = 0; i < SizeToInt(size); ++i) { | |||
| MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target"; | |||
| auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); | |||
| auto fn = NewValueNode(ptrGraph); | |||
| std::vector<AnfNodePtr> inputs2; | |||
| inputs2.push_back(fn); | |||
| if (fn_arg != nullptr) { | |||
| inputs2.push_back(fn_arg); | |||
| } | |||
| (void)std::transform( | |||
| arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), | |||
| [&func_graph, i](const std::pair<AnfNodePtr, Any> &item) { | |||
| return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); | |||
| }); | |||
| inputs.push_back(func_graph->NewCNode(inputs2)); | |||
| } | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::size_t size = type->elements().size(); | |||
| bool is_not_same = | |||
| std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) { | |||
| auto lhs = std::dynamic_pointer_cast<Tuple>(item.second); | |||
| MS_EXCEPTION_IF_NULL(lhs); | |||
| return lhs->elements().size() != size; | |||
| }); | |||
| if (is_not_same) { | |||
| MS_LOG(EXCEPTION) << "tuple in Map should have same length"; | |||
| } | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| for (int i = 0; i < SizeToInt(size); ++i) { | |||
| MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; | |||
| auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); | |||
| auto fn = NewValueNode(ptrGraph); | |||
| std::vector<AnfNodePtr> inputs2; | |||
| inputs2.push_back(fn); | |||
| if (fn_arg != nullptr) { | |||
| inputs2.push_back(fn_arg); | |||
| } | |||
| (void)std::transform( | |||
| arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), | |||
| [&func_graph, &i](std::pair<AnfNodePtr, Any> item) { | |||
| return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); | |||
| }); | |||
| inputs.push_back(func_graph->NewCNode(inputs2)); | |||
| } | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, | |||
| const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); | |||
| inputs.push_back(NewValueNode(type)); | |||
| std::size_t attrSize = type->GetAttributes().size(); | |||
| for (std::size_t i = 0; i < attrSize; ++i) { | |||
| MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; | |||
| auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); | |||
| auto fn = NewValueNode(ptrGraph); | |||
| std::vector<AnfNodePtr> inputs2; | |||
| inputs2.push_back(fn); | |||
| if (fn_arg != nullptr) { | |||
| inputs2.push_back(fn_arg); | |||
| } | |||
| int j = 0; | |||
| for (auto item : arg_pairs) { | |||
| inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); | |||
| j++; | |||
| } | |||
| inputs.push_back(func_graph->NewCNode(inputs2)); | |||
| } | |||
| return func_graph->NewCNode(inputs); | |||
| } | |||
| AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { | |||
| bool found = false; | |||
| TypeId id = kObjectTypeEnd; | |||
| std::pair<AnfNodePtr, TypePtr> pair; | |||
| for (auto &item : arg_pairs) { | |||
| pair = item; | |||
| MS_LOG(DEBUG) << "Map " << pair.second->ToString(); | |||
| id = item.second->type_id(); | |||
| if (nonleaf_.count(id)) { | |||
| found = true; | |||
| break; | |||
| } | |||
| } | |||
| if (found) { | |||
| // In the nonleaf case, all arguments must share the same generic (container) type. | |||
| bool is_not_same = | |||
| std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) { | |||
| if (item.first != pair.first) { | |||
| return item.second->type_id() != pair.second->type_id(); | |||
| } | |||
| return false; | |||
| }); | |||
| if (is_not_same) { | |||
| std::ostringstream oss; | |||
| oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n" | |||
| << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; | |||
| int idx = 0; | |||
| for (auto &item : arg_pairs) { | |||
| oss << ++idx << ": " << item.second->ToString() << "\n"; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n" | |||
| << oss.str() << pair.second->ToString() << "\n"; | |||
| } | |||
| } | |||
| switch (id) { | |||
| case kObjectTypeList: { | |||
| auto type = std::static_pointer_cast<List>(pair.second); | |||
| return FullMakeList(type, func_graph, fn_arg, arg_pairs); | |||
| } | |||
| case kObjectTypeTuple: { | |||
| auto type = std::static_pointer_cast<Tuple>(pair.second); | |||
| return FullMakeTuple(type, func_graph, fn_arg, arg_pairs); | |||
| } | |||
| case kObjectTypeClass: { | |||
| auto type = std::static_pointer_cast<Class>(pair.second); | |||
| return FullMakeClass(type, func_graph, fn_arg, arg_pairs); | |||
| } | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class " | |||
| << ", but got " << pair.second->ToString(); | |||
| } | |||
| } | |||
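`Make` dispatches on the first non-leaf argument type and rebuilds the same container kind with the leaf function applied element-wise. A rough plain-Python analogue of the list/tuple cases, assuming (as the checks above enforce) that all arguments are containers of the same kind and length:

```python
def map_make(fn, *seqs):
    """Conceptual analogue of Map::Make for the list/tuple cases (sketch only)."""
    first = seqs[0]
    if any(type(s) is not type(first) or len(s) != len(first) for s in seqs):
        raise TypeError("Map arguments must be sequences of the same type and length")
    # Apply fn to each "row" of elements and rebuild the same container type.
    return type(first)(fn(*row) for row in zip(*seqs))

print(map_make(lambda x, y: x + y, (1, 2, 3), (10, 20, 30)))  # (11, 22, 33)
print(map_make(lambda x: x * x, [1, 2, 3]))                   # [1, 4, 9]
```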
| FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { | |||
| FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); | |||
| ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); | |||
| ptrGraph->debug_info()->set_name("map"); | |||
| AnfNodePtr ptrFnArg = nullptr; | |||
| std::size_t i = 0; | |||
| if (fn_leaf_ == nullptr) { | |||
| ptrFnArg = ptrGraph->add_parameter(); | |||
| i = 1; | |||
| } | |||
| ArgsPairList arg_pairs; | |||
| std::size_t size = args_spec_list.size(); | |||
| for (; i < size; ++i) { | |||
| MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString(); | |||
| arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); | |||
| } | |||
| ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs)); | |||
| return ptrGraph; | |||
| } | |||
| abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||
| if (fn_leaf_ == nullptr) { | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| // Assert that map's function param does not contain free variables | |||
| if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) { | |||
| auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]); | |||
| auto func_graph = graph_func->func_graph(); | |||
| if (func_graph->parent() != nullptr) { | |||
| MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet."; | |||
| } | |||
| } | |||
| } | |||
| AbstractBasePtrList broadened; | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), | |||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| return arg->Broaden(); | |||
| }); | |||
| return broadened; | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) { | |||
| (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_") | |||
| .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf")) | |||
| .def(py::init<>()); | |||
| })); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
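The registration above exposes `Map_` to Python; together with the `Map` wrapper added to `composite/base.py` later in this change, a graph-mode usage sketch looks like the following. This assumes a MindSpore build that contains this patch; `square` is a MultitypeFuncGraph defined here purely for illustration.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C
from mindspore.ops import operations as P

square = C.MultitypeFuncGraph("square")

@square.register("Tensor")
def _square_tensor(x):
    return P.Square()(x)

class MapSquare(nn.Cell):
    def __init__(self):
        super(MapSquare, self).__init__()
        self.map = C.Map(square)

    def construct(self, x, y):
        # Applies `square` to every element of the tuple, returning a tuple.
        return self.map((x, y))

context.set_context(mode=context.GRAPH_MODE)
net = MapSquare()
out = net(Tensor(np.array([1.0, 2.0]).astype(np.float32)),
          Tensor(np.array([3.0, 4.0]).astype(np.float32)))
```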
| @@ -0,0 +1,98 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ | |||
| #define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ | |||
| #include <memory> | |||
| #include <set> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ir/dtype.h" | |||
| #include "ir/meta_func_graph.h" | |||
| #include "operator/composite/multitype_funcgraph.h" | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| namespace prim { | |||
| using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; | |||
| class Map : public MetaFuncGraph { | |||
| public: | |||
| explicit Map(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) | |||
| : MetaFuncGraph("map"), | |||
| fn_leaf_(fn_leaf), | |||
| broadcast_(false), | |||
| nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { | |||
| Init(); | |||
| } | |||
| Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { | |||
| Init(); | |||
| } | |||
| Map &operator=(const Map &h) { | |||
| if (this != &h) { | |||
| fn_leaf_ = h.fn_leaf_; | |||
| broadcast_ = h.broadcast_; | |||
| nonleaf_ = h.nonleaf_; | |||
| if (fn_leaf_) { | |||
| name_ = "map[" + fn_leaf_->name() + "]"; | |||
| } | |||
| } | |||
| return *this; | |||
| } | |||
| ~Map() override = default; | |||
| MS_DECLARE_PARENT(Map, MetaFuncGraph) | |||
| abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; | |||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; | |||
| MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } | |||
| private: | |||
| FuncGraphPtr GenerateLeafFunc(const size_t &args_size); | |||
| AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); | |||
| AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_pairs); | |||
| AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_pairs); | |||
| AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, | |||
| const ArgsPairList &arg_pairs); | |||
| AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); | |||
| void Init() { | |||
| if (fn_leaf_ != nullptr) { | |||
| name_ = "map[" + fn_leaf_->name() + "]"; | |||
| } | |||
| signatures_ = | |||
| // def map(func:read, *args:ref): | |||
| std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, | |||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | |||
| } | |||
| MultitypeFuncGraphPtr fn_leaf_; | |||
| bool broadcast_; | |||
| std::set<TypeId> nonleaf_; | |||
| }; | |||
| using MapPtr = std::shared_ptr<Map>; | |||
| class MapPy : public Map { | |||
| public: | |||
| explicit MapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : Map(fn_leaf) {} | |||
| ~MapPy() override = default; | |||
| MS_DECLARE_PARENT(MapPy, Map) | |||
| }; | |||
| using MapPyPtr = std::shared_ptr<MapPy>; | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ | |||
| @@ -14,9 +14,14 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <sstream> | |||
| #include "ir/dtype.h" | |||
| #include "common/utils.h" | |||
| #include "operator/ops.h" | |||
| #include "pipeline/static_analysis/param_validator.h" | |||
| #include "pipeline/static_analysis/prim.h" | |||
| #include "operator/ops.h" | |||
| #include "pipeline/static_analysis/utils.h" | |||
| #include "utils/symbolic.h" | |||
| @@ -50,6 +55,65 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit | |||
| return AbstractFunction::MakeAbstractFunction(jv); | |||
| } | |||
| class UndeterminedShapeType { | |||
| public: | |||
| explicit UndeterminedShapeType(const std::string &env_str) { | |||
| // param_name indices_shape indices_type values_shape values_type dense_shape | |||
| // export UNDETERMINED_SPARSE_SHAPE_TYPES="w1:2:Int32:2 1 2:Float32:3 1 2" | |||
| std::vector<string> fields; | |||
| string tmp; | |||
| std::stringstream input(env_str); | |||
| while (std::getline(input, tmp, ':')) { | |||
| fields.push_back(tmp); | |||
| } | |||
| if (fields.size() != fields_num) { | |||
| MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size(); | |||
| } | |||
| param_name_ = fields[0]; | |||
| indices_shape_ = GetShape(fields[1]); | |||
| indices_type_ = StringToType(fields[2]); | |||
| values_shape_ = GetShape(fields[3]); | |||
| values_type_ = StringToType(fields[4]); | |||
| auto dense_shape_vec = GetShape(fields[5]); | |||
| AbstractBasePtrList dense_shape_list; | |||
| (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list), | |||
| [](const auto &elem) { return FromValue(elem, false); }); | |||
| dense_shape_ = dense_shape_list; | |||
| } | |||
| const std::string ¶m_name() { return param_name_; } | |||
| const std::vector<int> &indices_shape() { return indices_shape_; } | |||
| const TypePtr &indices_type() { return indices_type_; } | |||
| const std::vector<int> &values_shape() { return values_shape_; } | |||
| const TypePtr &values_type() { return values_type_; } | |||
| const AbstractBasePtrList &dense_shape() { return dense_shape_; } | |||
| private: | |||
| std::string param_name_; | |||
| std::vector<int> indices_shape_; | |||
| TypePtr indices_type_; | |||
| std::vector<int> values_shape_; | |||
| TypePtr values_type_; | |||
| AbstractBasePtrList dense_shape_; | |||
| static const size_t fields_num; | |||
| std::vector<int> GetShape(const std::string &shape_str); | |||
| }; | |||
| std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) { | |||
| std::vector<int> ret; | |||
| std::istringstream iss(shape_str); | |||
| int elem; | |||
| while (iss >> elem) { | |||
| ret.emplace_back(elem); | |||
| } | |||
| return ret; | |||
| } | |||
| const size_t UndeterminedShapeType::fields_num = 6; | |||
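The constructor above splits a colon-separated environment string into six fields. The Python sketch below illustrates the expected layout using the default string that appears later in `InferImplEnvGetItem`; it is an illustration only and not part of the change.

```python
def parse_sparse_shape_types(env_str):
    """Split 'param:indices_shape:indices_type:values_shape:values_type:dense_shape'."""
    name, idx_shape, idx_type, val_shape, val_type, dense = env_str.split(":")

    def to_shape(s):
        return [int(x) for x in s.split()]

    return {
        "param_name": name,
        "indices_shape": to_shape(idx_shape),
        "indices_type": idx_type,
        "values_shape": to_shape(val_shape),
        "values_type": val_type,
        "dense_shape": to_shape(dense),
    }

print(parse_sparse_shape_types("w1:2:Int32:2 1 2:Float32:3 1 2"))
# {'param_name': 'w1', 'indices_shape': [2], 'indices_type': 'Int32',
#  'values_shape': [2, 1, 2], 'values_type': 'Float32', 'dense_shape': [3, 1, 2]}
```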
| AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| @@ -62,6 +126,31 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt | |||
| if (type->type_id() != kObjectTypeSymbolicKeyType) { | |||
| MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); | |||
| } | |||
| if (key->sparse_grad()) { | |||
| // Will be fixed once the undetermined type is ready | |||
| auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); | |||
| if (sparse_shape_types.empty()) { | |||
| sparse_shape_types = "w1:2:Int32:2 1 2:Float32:3 1 2"; | |||
| } | |||
| MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString() << ", Undetermined shape is " | |||
| << sparse_shape_types; | |||
| auto shape_types = UndeterminedShapeType(sparse_shape_types); | |||
| AbstractBasePtrList sparse_list; | |||
| // indices | |||
| auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.indices_type()); | |||
| auto indices = std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types.indices_shape())); | |||
| sparse_list.emplace_back(indices); | |||
| // values | |||
| auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.values_type()); | |||
| auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types.values_shape())); | |||
| sparse_list.emplace_back(dout); | |||
| // dense_shape | |||
| sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types.dense_shape())); | |||
| return std::make_shared<AbstractTuple>(sparse_list); | |||
| } | |||
| if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) { | |||
| return dflt; | |||
| } | |||
| @@ -80,8 +169,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt | |||
| CheckArgsSize(primitive->name(), args_spec_list, 3); | |||
| auto key = args_spec_list[1]; | |||
| auto value = args_spec_list[2]; | |||
| ValuePtr key_value_ptr = key->GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(key_value_ptr); | |||
| auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>(); | |||
| @@ -91,7 +178,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt | |||
| } | |||
| auto expected = key_value_track->abstract(); | |||
| MS_EXCEPTION_IF_NULL(expected); | |||
| (void)expected->Join(value); | |||
| return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); | |||
| } | |||
| @@ -126,7 +212,9 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| if (type->type_id() != kObjectTypeRefKey) { | |||
| MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); | |||
| } | |||
| return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||
| auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | |||
| ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| @@ -38,6 +38,7 @@ | |||
| #include "pipeline/remove_value_node_dup.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "vm/transform.h" | |||
| #include "parse/python_adapter.h" | |||
| namespace mindspore { | |||
| namespace pipeline { | |||
| @@ -228,6 +229,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| if (param_node->has_default()) { | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param()); | |||
| AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true); | |||
| auto sparse_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad")); | |||
| ptr->set_sparse_grad(sparse_grad); | |||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | |||
| args_spec.push_back(ptr); | |||
| @@ -51,6 +51,7 @@ ValuePtr AbstractBase::BuildValue() const { | |||
| AbstractBasePtr AbstractBase::Broaden() const { | |||
| AbstractBasePtr clone = Clone(); | |||
| clone->set_value(kAnyValue); | |||
| clone->set_sparse_grad(sparse_grad_); | |||
| return clone; | |||
| } | |||
| @@ -63,7 +64,8 @@ std::string AbstractBase::ToString() const { | |||
| MS_EXCEPTION_IF_NULL(type_); | |||
| MS_EXCEPTION_IF_NULL(shape_); | |||
| buffer << type_name() << "(" | |||
| << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")"; | |||
| << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() | |||
| << " sparse_grad: " << sparse_grad_ << ")"; | |||
| return buffer.str(); | |||
| } | |||
| @@ -72,16 +74,22 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden() | |||
| AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | |||
| MS_EXCEPTION_IF_NULL(other); | |||
| if (*this == *other) { | |||
| return shared_from_base<AbstractBase>(); | |||
| auto ret = shared_from_base<AbstractBase>(); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| return ret; | |||
| } | |||
| auto value_self = GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(value_self); | |||
| ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); | |||
| TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); | |||
| if (res_value == value_self) { | |||
| return shared_from_base<AbstractBase>(); | |||
| auto ret = shared_from_base<AbstractBase>(); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| return ret; | |||
| } | |||
| return std::make_shared<AbstractScalar>(res_value, res_type); | |||
| auto ret = std::make_shared<AbstractScalar>(res_value, res_type); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr AbstractType::Clone() const { | |||
| @@ -423,7 +431,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||
| } | |||
| auto element = element_->Join(other_tensor->element_); | |||
| auto shape = ShapeJoin(this->shape(), other_tensor->shape()); | |||
| return std::make_shared<AbstractTensor>(element, shape); | |||
| auto ret = std::make_shared<AbstractTensor>(element, shape); | |||
| ret->set_sparse_grad(sparse_grad()); | |||
| return ret; | |||
| } | |||
| bool AbstractTensor::operator==(const AbstractTensor &other) const { | |||
| @@ -463,6 +473,7 @@ AbstractBasePtr AbstractTensor::Clone() const { | |||
| ShapePtr shp = shape(); | |||
| clone->set_shape(shp->Clone()); | |||
| clone->set_value(GetValueTrack()); | |||
| clone->set_sparse_grad(sparse_grad()); | |||
| return clone; | |||
| } | |||
| @@ -472,6 +483,7 @@ AbstractBasePtr AbstractTensor::Broaden() const { | |||
| auto shp = shape(); | |||
| broaden->set_shape(shp->Clone()); | |||
| broaden->set_value(kAnyValue); | |||
| broaden->set_sparse_grad(sparse_grad()); | |||
| return broaden; | |||
| } | |||
| @@ -482,6 +494,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const { | |||
| shp->Broaden(); | |||
| broaden->set_shape(shp); | |||
| broaden->set_value(kAnyValue); | |||
| broaden->set_sparse_grad(sparse_grad()); | |||
| return broaden; | |||
| } | |||
| @@ -502,7 +515,8 @@ std::string AbstractTensor::ToString() const { | |||
| MS_EXCEPTION_IF_NULL(value_track); | |||
| buffer << type_name() << "(" | |||
| << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() | |||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"; | |||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() | |||
| << ")"; | |||
| return buffer.str(); | |||
| } | |||
| @@ -44,7 +44,7 @@ class AbstractBase : public Base { | |||
| public: | |||
| explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, | |||
| const BaseShapePtr &shape = kNoShape) | |||
| : value_(value), type_(type), shape_(shape) {} | |||
| : value_(value), type_(type), shape_(shape), sparse_grad_(false) {} | |||
| ~AbstractBase() override = default; | |||
| MS_DECLARE_PARENT(AbstractBase, Base) | |||
| @@ -53,11 +53,13 @@ class AbstractBase : public Base { | |||
| virtual bool operator==(const AbstractBase &other) const; | |||
| void set_value(const ValuePtr &value) { value_ = value; } | |||
| void set_sparse_grad(const bool &sparse_grad) { sparse_grad_ = sparse_grad; } | |||
| void set_type(const TypePtr &type) { type_ = type; } | |||
| void set_shape(const BaseShapePtr &shape) { shape_ = shape; } | |||
| void set_value_desc(const std::string &desc) { value_desc_ = desc; } | |||
| const std::string &value_desc() const { return value_desc_; } | |||
| ValuePtr GetValueTrack() const { return value_; } | |||
| bool sparse_grad() const { return sparse_grad_; } | |||
| TypePtr GetTypeTrack() const { return type_; } | |||
| BaseShapePtr GetShapeTrack() const { return shape_; } | |||
| @@ -85,6 +87,7 @@ class AbstractBase : public Base { | |||
| TypePtr type_; | |||
| BaseShapePtr shape_; | |||
| std::string value_desc_; // store initial value description for error report | |||
| bool sparse_grad_; | |||
| }; | |||
| class AbstractScalar : public AbstractBase { | |||
| @@ -851,7 +851,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| } | |||
| auto refkey = key_value->cast<RefKeyPtr>(); | |||
| if (refkey == nullptr) { | |||
| return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>()); | |||
| auto ret = std::make_shared<AbstractScalar>(type); | |||
| auto ref_value = ref_abs->ref(); | |||
| MS_EXCEPTION_IF_NULL(ref_value); | |||
| ret->set_sparse_grad(ref_value->sparse_grad()); | |||
| return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | |||
| } | |||
| std::string name = refkey->tag(); | |||
| @@ -865,6 +869,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||
| x = SensitivityTransform(x); | |||
| std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | |||
| std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | |||
| abs_scalar->set_sparse_grad(x->sparse_grad()); | |||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | |||
| } | |||
| }; | |||
| @@ -50,12 +50,14 @@ class Parameter: | |||
| requires_grad (bool): True if the parameter requires gradient. Default: True. | |||
| layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode, | |||
| broadcast and gradients communication would not be applied on parameters. Default: False. | |||
| sparse_grad (bool): True if the parameter's gradient is sparse. Default: False. | |||
| """ | |||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): | |||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=False): | |||
| self.set_parameter_data(default_input) | |||
| self.name = name | |||
| self.requires_grad = requires_grad | |||
| self.layerwise_parallel = layerwise_parallel | |||
| self.sparse_grad = sparse_grad | |||
| self._is_init = False | |||
| self._sliced = False | |||
| self.clone_info = _CloneInfo() | |||
| @@ -168,6 +170,17 @@ class Parameter: | |||
| raise TypeError("`requires_grad` parameter must be bool type") | |||
| self._requires_grad = value | |||
| @property | |||
| def sparse_grad(self): | |||
| """Return whether the parameter's gradient is sparse.""" | |||
| return self._sparse_grad | |||
| @sparse_grad.setter | |||
| def sparse_grad(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`sparse_grad` parameter must be bool type") | |||
| self._sparse_grad = value | |||
| @property | |||
| def data(self): | |||
| return self.default_input | |||
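Usage of the new flag, as exercised by the test at the end of this change (assumes a build that includes this patch):

```python
import numpy as np
from mindspore import Tensor, Parameter

# Mark a parameter whose gradient should be treated as sparse.
w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)
assert w1.sparse_grad is True

w1.sparse_grad = False          # the setter accepts bools only
try:
    w1.sparse_grad = "yes"      # anything else raises TypeError
except TypeError as err:
    print(err)
```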
| @@ -30,6 +30,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum() | |||
| transpose = P.Transpose() | |||
| shape_op = P.Shape() | |||
| reshape = P.Reshape() | |||
| size_op = P.Size() | |||
| invert_permutation = P.InvertPermutation() | |||
| logical_and = P.LogicalAnd() | |||
| @@ -284,6 +285,37 @@ def get_bprop_gather_v2(self): | |||
| return bprop | |||
| @bprop_getters.register(P.SparseGatherV2) | |||
| def get_bprop_sparse_gather_v2(self): | |||
| """Generate bprop for SparseGatherV2""" | |||
| def bprop(x, indices, axis, out, dout): | |||
| x_shp = shape_op(x) | |||
| if axis == 0: | |||
| indices_size = (size_op(indices),) | |||
| x_tail_shp = x_shp[1:] | |||
| values_shape = indices_size + x_tail_shp | |||
| values = reshape(dout, values_shape) | |||
| indices = reshape(indices, indices_size) | |||
| return (indices, values, x_shp), zeros_like(indices), zeros_like(axis) | |||
| if F.rank(dout) == 0: | |||
| dout = P.ExpandDims()(dout, -1) | |||
| if F.rank(indices) == 0: | |||
| indices = P.ExpandDims()(indices, -1) | |||
| out_shp = shape_op(dout) | |||
| ind_shp = shape_op(indices) | |||
| # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) | |||
| perm_1 = _generate_shape_index(out_shp, ind_shp, axis) | |||
| values_transpose = transpose(dout, perm_1) | |||
| params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) | |||
| # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) | |||
| perm_2 = _generate_inverse_index(x_shp, axis) | |||
| params_grad = transpose(params_grad, perm_2) | |||
| return params_grad, zeros_like(indices), zeros_like(axis) | |||
| return bprop | |||
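For the `axis == 0` branch above, the gradient of `x` is returned as an `(indices, values, dense_shape)` triple instead of a dense tensor. The numpy sketch below shows which dense gradient such a triple stands for; `densify` is a hypothetical helper written for illustration, not a MindSpore API.

```python
import numpy as np

def densify(indices, values, dense_shape):
    """Reconstruct the dense gradient represented by an (indices, values, dense_shape) triple."""
    grad = np.zeros(dense_shape, dtype=values.dtype)
    for i, row in zip(indices, values):
        grad[i] += row  # duplicate indices accumulate
    return grad

# Gathering rows [0, 1] of a (3, 1, 2) parameter with an all-ones upstream gradient:
indices = np.array([0, 1])
values = np.ones((2, 1, 2), dtype=np.float32)   # dout reshaped to indices_size + x_shp[1:]
print(densify(indices, values, (3, 1, 2)))
```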
| @bprop_getters.register(P.Range) | |||
| def get_bprop_range(self): | |||
| """Generate bprop for Range""" | |||
| @@ -20,7 +20,7 @@ Pre-defined combination of operators. | |||
| """ | |||
| from .base import GradOperation, HyperMap, MultitypeFuncGraph, add_flags, \ | |||
| from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \ | |||
| grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \ | |||
| core, env_get, tail, zip_operation | |||
| from .clip_ops import clip_by_value | |||
| @@ -19,7 +19,7 @@ | |||
| from functools import partial | |||
| from mindspore import context | |||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ | |||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ | |||
| TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ | |||
| from ...common import dtype as mstype | |||
| from ...common.api import ms_function, _pynative_exec | |||
| @@ -240,6 +240,69 @@ class HyperMap(HyperMap_): | |||
| return func(*args_list) | |||
| return tuple(map(hypermap, *args_list)) | |||
| class Map(Map_): | |||
| """ | |||
| Map applies the given operation to the input sequences, | |||
| calling it on every element of each sequence. | |||
| Args: | |||
| ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`, | |||
| the operation should be passed as the first input when the instance is called. | |||
| Inputs: | |||
| - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences of the same length, | |||
| and each row of the sequences is passed to the operation. e.g. If the length of args is 2, then for each `i` | |||
| within the sequence length, `(args[0][i], args[1][i])` will be the input of the operation. | |||
| If `ops` is `None`, the first input is the operation, and the others are the input sequences. | |||
| Outputs: | |||
| Sequence, with the same type and length as the input sequences. Each element is the result of applying | |||
| the operation to the corresponding row of elements, e.g. `operation(args[0][i], args[1][i])`. | |||
| (A plain-Python sketch of this calling convention follows the class below.) | |||
| """ | |||
| def __init__(self, ops=None): | |||
| self.ops = ops | |||
| if ops: | |||
| Map_.__init__(self, ops) | |||
| else: | |||
| Map_.__init__(self) | |||
| def __call__(self, *args): | |||
| func = args[0] | |||
| count = 0 | |||
| count_max = 1 | |||
| args_list = args[1:] | |||
| if self.ops is not None: | |||
| func = self.ops | |||
| args_list = args | |||
| for item in args_list: | |||
| if isinstance(item, (tuple, list)): | |||
| count_max = len(item) | |||
| break | |||
| def get_item(x): | |||
| nonlocal count | |||
| if isinstance(x, (tuple, list)): | |||
| return x[count] | |||
| return x | |||
| for i in range(count_max): | |||
| true_args = tuple(map(get_item, args_list)) | |||
| func(*true_args) | |||
| count = i + 1 | |||
| return True | |||
| def register(self, *type_names): | |||
| """Register a function for the given type string.""" | |||
| def deco(fn): | |||
| self.register_fn(type_names, fn) | |||
| return fn | |||
| return deco | |||
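A plain-Python sketch of the calling convention documented above: if `ops` was given at construction time it is used directly, otherwise the first call argument is taken as the function; non-sequence arguments are broadcast to every row, mirroring `get_item` in `__call__`. This only illustrates the intended element-wise semantics; the eager `__call__` above runs the calls for their side effects and returns `True`.

```python
def toy_map(ops, *args):
    """Illustration of Map's calling convention (not the MindSpore implementation)."""
    if ops is None:
        fn, inputs = args[0], args[1:]
    else:
        fn, inputs = ops, args
    length = next(len(x) for x in inputs if isinstance(x, (tuple, list)))
    rows = (
        tuple(x[i] if isinstance(x, (tuple, list)) else x for x in inputs)
        for i in range(length)
    )
    return tuple(fn(*row) for row in rows)

print(toy_map(None, lambda x, y: x + y, (1, 2, 3), (10, 20, 30)))  # (11, 22, 33)
print(toy_map(lambda s, x: s * x, 10, (1, 2, 3)))                  # (10, 20, 30)
```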
| class _ListAppend(ListAppend_): | |||
| """ | |||
| A metafuncgraph class that append one element to list. | |||
| @@ -21,7 +21,7 @@ A collection of operators to build nerual networks or computing functions. | |||
| from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Diag, DiagPart, DType, ExpandDims, Eye, | |||
| Fill, GatherNd, GatherV2, InvertPermutation, | |||
| Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, | |||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | |||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, | |||
| SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | |||
| @@ -122,6 +122,7 @@ __all__ = [ | |||
| 'Transpose', | |||
| 'OneHot', | |||
| 'GatherV2', | |||
| 'SparseGatherV2', | |||
| 'Concat', | |||
| 'Pack', | |||
| 'Unpack', | |||
| @@ -526,6 +526,29 @@ class GatherV2(PrimitiveWithInfer): | |||
| return out | |||
| class SparseGatherV2(GatherV2): | |||
| """ | |||
| Returns a slice of input tensor based on the specified indices and axis. | |||
| Inputs: | |||
| - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||
| The original Tensor. | |||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||
| Specifies the indices of elements of the original Tensor. Must be in the range | |||
| `[0, input_param.shape()[axis])`. | |||
| - **axis** (int) - Specifies the dimension index to gather indices. | |||
| Outputs: | |||
| Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. | |||
| Examples: | |||
| >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32) | |||
| >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32) | |||
| >>> axis = 1 | |||
| >>> out = P.SparseGatherV2()(input_params, input_indices, axis) | |||
| """ | |||
| class Range(PrimitiveWithInfer): | |||
| r""" | |||
| Creates a sequence of numbers. | |||
| @@ -332,6 +332,8 @@ class CheckBprop(PrimitiveWithInfer): | |||
| def infer_shape(self, xshapes, yshapes): | |||
| tips = f'Bprop of {self.prim_to_check}' | |||
| validator.check_value_type('grads', xshapes, (tuple,), tips) | |||
| validator.check_value_type('params', yshapes, (tuple,), tips) | |||
| if len(xshapes) < len(yshapes): | |||
| raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," | |||
| f" but got {len(xshapes)}.") | |||
| @@ -348,6 +350,8 @@ class CheckBprop(PrimitiveWithInfer): | |||
| def infer_dtype(self, xdtypes, ydtypes): | |||
| tips = f'Bprop of {self.prim_to_check}' | |||
| validator.check_value_type('grads', xdtypes, (tuple,), tips) | |||
| validator.check_value_type('params', ydtypes, (tuple,), tips) | |||
| if len(xdtypes) < len(ydtypes): | |||
| raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," | |||
| f" but got {len(xdtypes)}.") | |||
| @@ -0,0 +1,173 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test adam """ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter, context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Optimizer | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") | |||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): | |||
| op_mul = P.Mul() | |||
| op_square = P.Square() | |||
| op_sqrt = P.Sqrt() | |||
| op_cast = P.Cast() | |||
| op_reshape = P.Reshape() | |||
| op_shape = P.Shape() | |||
| param_fp32 = op_cast(param, mstype.float32) | |||
| m_fp32 = op_cast(m, mstype.float32) | |||
| v_fp32 = op_cast(v, mstype.float32) | |||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||
| next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) | |||
| next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) | |||
| - beta2, op_square(gradient_fp32)) | |||
| update = next_m / (op_sqrt(next_v) + eps) | |||
| if decay_flag: | |||
| update = update + op_mul(weight_decay_tensor, param_fp32) | |||
| update_with_lr = op_mul(lr, update) | |||
| next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) | |||
| next_v = F.depend(next_v, F.assign(param, next_param)) | |||
| next_v = F.depend(next_v, F.assign(m, next_m)) | |||
| next_v = F.depend(next_v, F.assign(v, next_v)) | |||
| return next_v | |||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tuple", "Bool") | |||
| def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): | |||
| return gradient[2][2] | |||
| def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||
| """Check the type of inputs.""" | |||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||
| validator.check_value_type("eps", eps, [float], prim_name) | |||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | |||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||
| class AdamWeightDecaySparse(Optimizer): | |||
| """ | |||
| Implements the Adam algorithm with weight decay. | |||
| Args: | |||
| params (list[Parameter]): A list of parameters to be updated. Each element in `params` | |||
| should be a mindspore.Parameter. | |||
| learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. If the learning_rate is an | |||
| Iterable or a 1-D Tensor, a dynamic learning rate is used and the i-th step | |||
| takes the i-th value as the learning rate. If the learning_rate is a float or | |||
| a 0-D Tensor, a fixed learning rate is used. | |||
| Other cases are not supported. Default: 1e-3. | |||
| beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. | |||
| Should be in range (0.0, 1.0). | |||
| beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. | |||
| Should be in range (0.0, 1.0). | |||
| eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. | |||
| Should be greater than 0. | |||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||
| lambda x: 'beta' not in x.name and 'gamma' not in x.name. | |||
| Inputs: | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`, | |||
| and might be in sparse format. | |||
| Outputs: | |||
| tuple[Parameter], the updated velocity value, the shape is the same as `params`. | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> optim = AdamWeightDecaySparse(params=net.trainable_params()) | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||
| """ | |||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(AdamWeightDecaySparse, self).__init__(learning_rate, params) | |||
| if self.is_group: | |||
| raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") | |||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||
| self.eps = Tensor(np.array([eps]).astype(np.float32)) | |||
| self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) | |||
| self.params = self.parameters | |||
| self.moments1 = self.params.clone(prefix="adam_m", init='zeros') | |||
| self.moments2 = self.params.clone(prefix="adam_v", init='zeros') | |||
| self.decay_flag = tuple(decay_filter(x) for x in self.params) | |||
| self.map = C.Map() | |||
| def construct(self, gradients): | |||
| lr = self.get_lr() | |||
| updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor), | |||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||
| return updated_velocity | |||
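In `construct` above, `F.partial` binds the scalar hyper-parameters so that `Map` only iterates over the per-parameter tuples (`params`, `moments1`, `moments2`, `gradients`, `decay_flag`). In plain Python the pattern is roughly the following; `adam_step` is a stand-in for illustration, not the real update rule.

```python
from functools import partial

def adam_step(beta1, beta2, eps, lr, weight_decay, param, m, v, grad, decay_flag):
    # Stand-in for adam_opt_for_map: any per-parameter update would go here.
    return param - lr * grad

params = (1.0, 2.0)
moments1 = (0.0, 0.0)
moments2 = (0.0, 0.0)
grads = (0.1, 0.2)
decay_flags = (True, False)

# Bind the scalars once, then map the bound function across the per-parameter tuples.
step = partial(adam_step, 0.9, 0.999, 1e-6, 0.01, 0.0)
updated = tuple(step(*row) for row in zip(params, moments1, moments2, grads, decay_flags))
print(updated)  # (0.999, 1.998)
```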
| def test_AdamWeightDecaySparse(): | |||
| """ test_AdamWeightDecaySparse """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class Loss(nn.Cell): | |||
| def __init__(self): | |||
| super(Loss, self).__init__() | |||
| def construct(self, base, target): | |||
| return base | |||
| class NetWithSparseGatherV2(nn.Cell): | |||
| def __init__(self): | |||
| super(NetWithSparseGatherV2, self).__init__() | |||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True) | |||
| self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") | |||
| self.gatherv2 = P.SparseGatherV2() | |||
| self.axis = 0 | |||
| def construct(self, indices): | |||
| return self.gatherv2(self.w1, indices, self.axis) * self.w2 | |||
| inputs = Tensor(np.array([0, 1]).astype(np.int32)) | |||
| label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) | |||
| net = NetWithSparseGatherV2() | |||
| net.set_train() | |||
| loss = Loss() | |||
| optimizer = AdamWeightDecaySparse(net.trainable_params()) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||
| _executor.compile(train_network, inputs, label) | |||