| @@ -17,6 +17,7 @@ | |||||
| """Resources for ast tree parse.""" | """Resources for ast tree parse.""" | ||||
| import ast | import ast | ||||
| import math | import math | ||||
| from mindspore import IndexedSlices | |||||
| from mindspore.ops.composite import multitype_ops | from mindspore.ops.composite import multitype_ops | ||||
| from mindspore.ops import functional as F, composite as C | from mindspore.ops import functional as F, composite as C | ||||
| from . import standard_method as M | from . import standard_method as M | ||||
| @@ -135,4 +136,7 @@ convert_object_map = { | |||||
| math.sin: NO_IMPLEMENT, | math.sin: NO_IMPLEMENT, | ||||
| math.cos: NO_IMPLEMENT, | math.cos: NO_IMPLEMENT, | ||||
| math.tan: NO_IMPLEMENT, | math.tan: NO_IMPLEMENT, | ||||
| # user defined | |||||
| IndexedSlices: F.make_indexed_slices, | |||||
| } | } | ||||
| @@ -120,6 +120,10 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s | |||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); | ||||
| } | } | ||||
| } | } | ||||
| } else if (type->isa<IndexedSlicesType>()) { | |||||
| // Do Nothing | |||||
| } else if (type->isa<UndeterminedType>()) { | |||||
| // Do Nothing | |||||
| } else if (type->isa<Tuple>()) { | } else if (type->isa<Tuple>()) { | ||||
| TuplePtr tuple_type = dyn_cast<Tuple>(type); | TuplePtr tuple_type = dyn_cast<Tuple>(type); | ||||
| type_proto->set_data_type(irpb::DT_TUPLE); | type_proto->set_data_type(irpb::DT_TUPLE); | ||||
| @@ -94,6 +94,48 @@ bool Slice::operator==(const Type &other) const { | |||||
| std::string Slice::DumpText() const { return ToString(); } | std::string Slice::DumpText() const { return ToString(); } | ||||
| TypePtr UndeterminedType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<UndeterminedType>(); | |||||
| } | |||||
| return std::make_shared<UndeterminedType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string UndeterminedType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string UndeterminedType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string UndeterminedType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "Undetermined"; | |||||
| } | |||||
| return "Undetermined[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool UndeterminedType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| TypePtr TensorType::DeepCopy() const { | TypePtr TensorType::DeepCopy() const { | ||||
| MS_EXCEPTION_IF_NULL(element_type_); | MS_EXCEPTION_IF_NULL(element_type_); | ||||
| if (IsGeneric()) { | if (IsGeneric()) { | ||||
| @@ -137,6 +179,48 @@ bool TensorType::operator==(const Type &other) const { | |||||
| return *element_type_ == *other_elem_type; | return *element_type_ == *other_elem_type; | ||||
| } | } | ||||
| TypePtr IndexedSlicesType::DeepCopy() const { | |||||
| MS_EXCEPTION_IF_NULL(element_type_); | |||||
| if (IsGeneric()) { | |||||
| return std::make_shared<IndexedSlicesType>(); | |||||
| } | |||||
| return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy()); | |||||
| } | |||||
| std::string IndexedSlicesType::ToReprString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "IndexedSlices"; | |||||
| } | |||||
| return "IndexedSlices[" + element_type_->ToReprString() + "]"; | |||||
| } | |||||
| std::string IndexedSlicesType::ToString() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "IndexedSlices"; | |||||
| } | |||||
| return "IndexedSlices[" + element_type_->ToString() + "]"; | |||||
| } | |||||
| std::string IndexedSlicesType::DumpText() const { | |||||
| if (element_type_ == nullptr) { | |||||
| return "IndexedSlices"; | |||||
| } | |||||
| return "IndexedSlices[" + element_type_->DumpText() + "]"; | |||||
| } | |||||
| bool IndexedSlicesType::operator==(const Type &other) const { | |||||
| if (!IsSameObjectType(*this, other)) { | |||||
| return false; | |||||
| } | |||||
| auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_; | |||||
| if (element_type_ == nullptr && other_elem_type == nullptr) { | |||||
| return true; | |||||
| } else if (element_type_ == nullptr || other_elem_type == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return *element_type_ == *other_elem_type; | |||||
| } | |||||
| Function::Function() : Object(kObjectTypeFunction) { | Function::Function() : Object(kObjectTypeFunction) { | ||||
| args_ = std::vector<TypePtr>(); | args_ = std::vector<TypePtr>(); | ||||
| retval_ = nullptr; | retval_ = nullptr; | ||||
| @@ -108,10 +108,34 @@ class Slice : public Object { | |||||
| }; | }; | ||||
| using SlicePtr = std::shared_ptr<Slice>; | using SlicePtr = std::shared_ptr<Slice>; | ||||
| class UndeterminedType : public Object { | |||||
| public: | |||||
| UndeterminedType() : Object(kObjectTypeUndeterminedType) {} | |||||
| explicit UndeterminedType(const TypePtr &ele) | |||||
| : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} | |||||
| ~UndeterminedType() override = default; | |||||
| MS_DECLARE_PARENT(UndeterminedType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| protected: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>; | |||||
| class TensorType : public Object { | class TensorType : public Object { | ||||
| public: | public: | ||||
| TensorType() : Object(kObjectTypeTensorType) {} | |||||
| explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} | |||||
| TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} | |||||
| explicit TensorType(const TypePtr &ele) | |||||
| : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~TensorType() override = default; | ~TensorType() override = default; | ||||
| MS_DECLARE_PARENT(TensorType, Object) | MS_DECLARE_PARENT(TensorType, Object) | ||||
| @@ -130,6 +154,29 @@ class TensorType : public Object { | |||||
| }; | }; | ||||
| using TensorTypePtr = std::shared_ptr<TensorType>; | using TensorTypePtr = std::shared_ptr<TensorType>; | ||||
| class IndexedSlicesType : public Object { | |||||
| public: | |||||
| IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {} | |||||
| explicit IndexedSlicesType(const TypePtr &ele) | |||||
| : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {} | |||||
| ~IndexedSlicesType() override = default; | |||||
| MS_DECLARE_PARENT(IndexedSlicesType, Object) | |||||
| TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; } | |||||
| const TypePtr element() const { return element_type_; } | |||||
| void set_element(const TypePtr &element_type) { element_type_ = element_type; } | |||||
| TypePtr DeepCopy() const override; | |||||
| std::string ToString() const override; | |||||
| std::string ToReprString() const override; | |||||
| std::string DumpText() const override; | |||||
| bool operator==(const Type &other) const override; | |||||
| private: | |||||
| TypePtr element_type_; | |||||
| }; | |||||
| using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>; | |||||
| class Function : public Object { | class Function : public Object { | ||||
| public: | public: | ||||
| Function(); | Function(); | ||||
| @@ -255,6 +302,8 @@ TypePtr StringToType(const std::string &type_name); | |||||
| // Judge whether x is predicate or is a subclass of predicate. | // Judge whether x is predicate or is a subclass of predicate. | ||||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); | bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); | ||||
| bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type); | |||||
| // Whether t1 is identity or a subclass of t2. | // Whether t1 is identity or a subclass of t2. | ||||
| bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); | bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); | ||||
| @@ -115,6 +115,10 @@ const char *ObjectIdLabel(const TypeId &v) { | |||||
| return "kObjectTypeKeyword"; | return "kObjectTypeKeyword"; | ||||
| case kObjectTypeTensorType: | case kObjectTypeTensorType: | ||||
| return "kObjectTypeTensorType"; | return "kObjectTypeTensorType"; | ||||
| case kObjectTypeIndexedSlicesType: | |||||
| return "kObjectTypeIndexedSlicesType"; | |||||
| case kObjectTypeUndeterminedType: | |||||
| return "kObjectTypeUndeterminedType"; | |||||
| case kObjectTypeDictionary: | case kObjectTypeDictionary: | ||||
| return "kObjectTypeDictionary"; | return "kObjectTypeDictionary"; | ||||
| case kObjectTypeClass: | case kObjectTypeClass: | ||||
| @@ -67,6 +67,7 @@ class Type : public Value { | |||||
| virtual bool equal(const TypePtr other) const { return *this == *other; } | virtual bool equal(const TypePtr other) const { return *this == *other; } | ||||
| virtual TypeId object_type() const { return kTypeUnknown; } | virtual TypeId object_type() const { return kTypeUnknown; } | ||||
| virtual TypeId parent_type() const { return kTypeUnknown; } | |||||
| virtual TypeId number_type() const { return kTypeUnknown; } | virtual TypeId number_type() const { return kTypeUnknown; } | ||||
| virtual TypePtr DeepCopy() const = 0; | virtual TypePtr DeepCopy() const = 0; | ||||
| virtual TypePtr Clone() const { return DeepCopy(); } | virtual TypePtr Clone() const { return DeepCopy(); } | ||||
| @@ -97,13 +98,16 @@ using TypePtrList = std::vector<TypePtr>; | |||||
| // | // | ||||
| class Object : public Type { | class Object : public Type { | ||||
| public: | public: | ||||
| Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject) {} | |||||
| Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {} | |||||
| explicit Object(const TypeId object_type, bool is_generic = true) | explicit Object(const TypeId object_type, bool is_generic = true) | ||||
| : Type(kMetaTypeObject, is_generic), object_type_(object_type) {} | |||||
| : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {} | |||||
| explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true) | |||||
| : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {} | |||||
| ~Object() override = default; | ~Object() override = default; | ||||
| MS_DECLARE_PARENT(Object, Type) | MS_DECLARE_PARENT(Object, Type) | ||||
| TypeId object_type() const override { return object_type_; } | TypeId object_type() const override { return object_type_; } | ||||
| TypeId parent_type() const override { return parent_type_; } | |||||
| TypeId type_id() const override { return object_type_; } | TypeId type_id() const override { return object_type_; } | ||||
| TypeId generic_type_id() const override { return kMetaTypeObject; } | TypeId generic_type_id() const override { return kMetaTypeObject; } | ||||
| bool equal(const TypePtr other) const override; | bool equal(const TypePtr other) const override; | ||||
| @@ -114,6 +118,7 @@ class Object : public Type { | |||||
| private: | private: | ||||
| const TypeId object_type_; | const TypeId object_type_; | ||||
| const TypeId parent_type_; | |||||
| }; | }; | ||||
| std::ostream &operator<<(std::ostream &os, const TypePtrList &types); | std::ostream &operator<<(std::ostream &os, const TypePtrList &types); | ||||
| @@ -50,6 +50,8 @@ enum TypeId : int { | |||||
| kObjectTypeSlice, | kObjectTypeSlice, | ||||
| kObjectTypeKeyword, | kObjectTypeKeyword, | ||||
| kObjectTypeTensorType, | kObjectTypeTensorType, | ||||
| kObjectTypeIndexedSlicesType, | |||||
| kObjectTypeUndeterminedType, | |||||
| kObjectTypeClass, | kObjectTypeClass, | ||||
| kObjectTypeDictionary, | kObjectTypeDictionary, | ||||
| kObjectTypeFunction, | kObjectTypeFunction, | ||||
| @@ -192,6 +192,40 @@ TypePtr TensorStrToType(const std::string &type_name) { | |||||
| return type; | return type; | ||||
| } | } | ||||
| TypePtr IndexedSlicesStrToType(const std::string &type_name) { | |||||
| if (type_name == "IndexedSlices") { | |||||
| return std::make_shared<IndexedSlicesType>(); | |||||
| } | |||||
| auto start = type_name.find_first_of('[') + 1; | |||||
| auto end = type_name.find_last_of(']'); | |||||
| if (start >= type_name.size()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto element_str = type_name.substr(start, end - start); | |||||
| auto element_type = StringToType(element_str); | |||||
| if (element_type == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return std::make_shared<IndexedSlicesType>(element_type); | |||||
| } | |||||
| TypePtr UndeterminedStrToType(const std::string &type_name) { | |||||
| if (type_name == "Undetermined") { | |||||
| return std::make_shared<UndeterminedType>(); | |||||
| } | |||||
| auto start = type_name.find_first_of('[') + 1; | |||||
| auto end = type_name.find_last_of(']'); | |||||
| if (start >= type_name.size()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto element_str = type_name.substr(start, end - start); | |||||
| auto element_type = StringToType(element_str); | |||||
| if (element_type == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return std::make_shared<UndeterminedType>(element_type); | |||||
| } | |||||
| TypePtr ListStrToType(const std::string &type_name) { | TypePtr ListStrToType(const std::string &type_name) { | ||||
| TypePtr type = nullptr; | TypePtr type = nullptr; | ||||
| if (type_name == "List") { | if (type_name == "List") { | ||||
| @@ -313,6 +347,10 @@ TypePtr StringToType(const std::string &type_name) { | |||||
| type = StringToNumberType<Float>(type_name, "Float"); | type = StringToNumberType<Float>(type_name, "Float"); | ||||
| } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) { | } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) { | ||||
| type = TensorStrToType(type_name); | type = TensorStrToType(type_name); | ||||
| } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) { | |||||
| type = UndeterminedStrToType(type_name); | |||||
| } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) { | |||||
| type = IndexedSlicesStrToType(type_name); | |||||
| } else if (type_name.compare(0, strlen("List"), "List") == 0) { | } else if (type_name.compare(0, strlen("List"), "List") == 0) { | ||||
| type = ListStrToType(type_name); | type = ListStrToType(type_name); | ||||
| } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { | } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { | ||||
| @@ -340,6 +378,20 @@ TypePtr StringToType(const std::string &type_name) { | |||||
| return type; | return type; | ||||
| } | } | ||||
| bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) { | |||||
| if (x == nullptr || base_type == nullptr) { | |||||
| MS_LOG(ERROR) << "Type is nullptr."; | |||||
| return false; | |||||
| } | |||||
| if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { | |||||
| return false; | |||||
| } | |||||
| if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | ||||
| if (x == nullptr || base_type == nullptr) { | if (x == nullptr || base_type == nullptr) { | ||||
| MS_LOG(ERROR) << "Type is nullptr."; | MS_LOG(ERROR) << "Type is nullptr."; | ||||
| @@ -481,6 +533,10 @@ REGISTER_PYBIND_DEFINE( | |||||
| TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>())))); | TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>())))); | ||||
| return data; | return data; | ||||
| })); | })); | ||||
| (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType") | |||||
| .def(py::init()); | |||||
| (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType") | |||||
| .def(py::init()); | |||||
| (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function") | (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval")); | .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval")); | ||||
| @@ -501,6 +557,8 @@ const TypePtr kTypeExternal = std::make_shared<External>(); | |||||
| const TypePtr kTypeEnv = std::make_shared<EnvType>(); | const TypePtr kTypeEnv = std::make_shared<EnvType>(); | ||||
| const TypePtr kTypeType = std::make_shared<TypeType>(); | const TypePtr kTypeType = std::make_shared<TypeType>(); | ||||
| const TypePtr kTensorType = std::make_shared<TensorType>(); | const TypePtr kTensorType = std::make_shared<TensorType>(); | ||||
| const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>(); | |||||
| const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>(); | |||||
| const TypePtr kString = std::make_shared<String>(); | const TypePtr kString = std::make_shared<String>(); | ||||
| const TypePtr kList = std::make_shared<List>(); | const TypePtr kList = std::make_shared<List>(); | ||||
| const TypePtr kTuple = std::make_shared<Tuple>(); | const TypePtr kTuple = std::make_shared<Tuple>(); | ||||
| @@ -93,15 +93,17 @@ static TypePtr UnwrapRef(const TypePtr &type) { | |||||
| } | } | ||||
| return type; | return type; | ||||
| } | } | ||||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||||
| bool find_fn = false; | |||||
| py::function py_fn; | |||||
| // Return Exact match if exists, else return non ambiguous sub class match | |||||
| // Return py::none() if matching is ambiguous | |||||
| const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { | |||||
| // Exact match | |||||
| for (auto &item : fn_cache_py_) { | for (auto &item : fn_cache_py_) { | ||||
| TypePtrList sign = item.first; | TypePtrList sign = item.first; | ||||
| if (sign.size() != types.size()) { | if (sign.size() != types.size()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| bool match = true; | |||||
| auto match = true; | |||||
| for (size_t i = 0; i < sign.size(); ++i) { | for (size_t i = 0; i < sign.size(); ++i) { | ||||
| if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { | if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { | ||||
| match = false; | match = false; | ||||
| @@ -111,13 +113,45 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||||
| if (!match) { | if (!match) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| find_fn = true; | |||||
| py_fn = item.second; | |||||
| break; | |||||
| return item.second; | |||||
| } | } | ||||
| // Try best match | |||||
| py::function py_fn_subclass; | |||||
| size_t subclass_match_cnt = 0; | |||||
| for (auto &item : fn_cache_py_) { | |||||
| TypePtrList sign = item.first; | |||||
| if (sign.size() != types.size()) { | |||||
| continue; | |||||
| } | |||||
| auto match = true; | |||||
| for (size_t i = 0; i < sign.size(); ++i) { | |||||
| if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) && | |||||
| !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) { | |||||
| match = false; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!match) { | |||||
| continue; | |||||
| } | |||||
| py_fn_subclass = item.second; | |||||
| subclass_match_cnt++; | |||||
| } | |||||
| if (subclass_match_cnt > 1) { | |||||
| MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass"; | |||||
| } | |||||
| if (subclass_match_cnt == 1) { | |||||
| MS_LOG(DEBUG) << "Found one subclass match"; | |||||
| return py_fn_subclass; | |||||
| } | |||||
| return py::none(); | |||||
| } | |||||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||||
| auto py_fn = SignMatch(types); | |||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| buffer << types; | buffer << types; | ||||
| if (find_fn) { | |||||
| if (py_fn != py::none()) { | |||||
| FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); | FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); | ||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); | MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); | ||||
| @@ -54,6 +54,7 @@ class MultitypeFuncGraph : public MetaFuncGraph { | |||||
| } | } | ||||
| private: | private: | ||||
| const py::function SignMatch(const TypePtrList &types); | |||||
| std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_; | std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_; | ||||
| std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_; | std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_; | ||||
| }; | }; | ||||
| @@ -277,5 +277,12 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary | |||||
| const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary"); | ||||
| const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary"); | ||||
| const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug"); | const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug"); | ||||
| // IndexedSlices | |||||
| const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices"); | |||||
| const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues"); | |||||
| const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices"); | |||||
| const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape"); | |||||
| const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices"); | |||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -287,6 +287,13 @@ extern const PrimitivePtr kPrimMirror; | |||||
| extern const PrimitivePtr kPrimVirtualDiv; | extern const PrimitivePtr kPrimVirtualDiv; | ||||
| extern const PrimitivePtr kPrimVirtualDataset; | extern const PrimitivePtr kPrimVirtualDataset; | ||||
| // IndexedSlices | |||||
| extern const PrimitivePtr kPrimMakeIndexedSlices; | |||||
| extern const PrimitivePtr kPrimIndexedSlicesGetValues; | |||||
| extern const PrimitivePtr kPrimIndexedSlicesGetIndices; | |||||
| extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; | |||||
| extern const PrimitivePtr kPrimIsIndexedSlices; | |||||
| class DoSignaturePrimitive : public Primitive { | class DoSignaturePrimitive : public Primitive { | ||||
| public: | public: | ||||
| explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "pipeline/static_analysis/prim.h" | #include "pipeline/static_analysis/prim.h" | ||||
| #include "pipeline/static_analysis/utils.h" | #include "pipeline/static_analysis/utils.h" | ||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #include "utils/context/ms_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace abstract { | namespace abstract { | ||||
| @@ -173,6 +174,13 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt | |||||
| return std::make_shared<AbstractTuple>(sparse_list); | return std::make_shared<AbstractTuple>(sparse_list); | ||||
| } | } | ||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||||
| if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) { | |||||
| auto dflt_tensor = dflt->cast<AbstractTensorPtr>(); | |||||
| return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); | |||||
| } | |||||
| if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) { | if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) { | ||||
| return dflt; | return dflt; | ||||
| } | } | ||||
| @@ -236,6 +244,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & | |||||
| } | } | ||||
| auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); | ||||
| ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); | ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); | ||||
| ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -437,5 +446,72 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv | |||||
| } | } | ||||
| return std::make_shared<AbstractScalar>(kAnyValue, kBool); | return std::make_shared<AbstractScalar>(kAnyValue, kBool); | ||||
| } | } | ||||
| AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 3); | |||||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2); | |||||
| auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(dense_shape_value); | |||||
| auto shp = dense_shape_value->value(); | |||||
| std::vector<int> dense_shape_vec; | |||||
| (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), | |||||
| [](const ValuePtr &e) -> int { | |||||
| auto elem = GetValue<int>(e); | |||||
| return elem; | |||||
| }); | |||||
| auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec); | |||||
| ret->set_indices(indices); | |||||
| ret->set_values(values); | |||||
| ret->set_dense_shape(dense_shape); | |||||
| return ret; | |||||
| } | |||||
| AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(indexed_slices->values()); | |||||
| return indexed_slices->values(); | |||||
| } | |||||
| AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(indexed_slices->indices()); | |||||
| return indexed_slices->indices(); | |||||
| } | |||||
| AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // Inputs: two tensors and a tuple. | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape()); | |||||
| return indexed_slices->dense_shape(); | |||||
| } | |||||
| AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 1); | |||||
| bool ret = false; | |||||
| if (args_spec_list[0]->isa<AbstractIndexedSlices>()) { | |||||
| ret = true; | |||||
| } | |||||
| MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString(); | |||||
| return std::make_shared<AbstractScalar>(ret); | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,6 +36,7 @@ using mindspore::abstract::AbstractJTagged; | |||||
| using mindspore::abstract::AbstractList; | using mindspore::abstract::AbstractList; | ||||
| using mindspore::abstract::AbstractScalar; | using mindspore::abstract::AbstractScalar; | ||||
| using mindspore::abstract::AbstractTuple; | using mindspore::abstract::AbstractTuple; | ||||
| using mindspore::abstract::AbstractUndetermined; | |||||
| static AbstractBasePtr Reabs(const AbstractBasePtr &t) { | static AbstractBasePtr Reabs(const AbstractBasePtr &t) { | ||||
| if (t == nullptr) { | if (t == nullptr) { | ||||
| @@ -78,7 +79,7 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(cons); | MS_EXCEPTION_IF_NULL(cons); | ||||
| auto dt = data->abstract(); | auto dt = data->abstract(); | ||||
| if (dt == nullptr) { | |||||
| if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -42,6 +42,7 @@ | |||||
| #include "optimizer/irpass/tile_eliminate.h" | #include "optimizer/irpass/tile_eliminate.h" | ||||
| #include "optimizer/irpass/transpose_eliminate.h" | #include "optimizer/irpass/transpose_eliminate.h" | ||||
| #include "optimizer/opt.h" | #include "optimizer/opt.h" | ||||
| #include "optimizer/irpass/indexed_slices_eliminate.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -153,6 +154,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| // Mark interface fusion | // Mark interface fusion | ||||
| mark_interface_fusion_ = | mark_interface_fusion_ = | ||||
| MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | ||||
| // IndexedSlices Eliminate | |||||
| indexed_slices_eliminate_ = MakeSubstitution( | |||||
| std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate", | |||||
| {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); | |||||
| } | } | ||||
| ResolveIRPassLib::ResolveIRPassLib() { | ResolveIRPassLib::ResolveIRPassLib() { | ||||
| @@ -104,6 +104,9 @@ class OptimizeIRPassLib { | |||||
| // Fusion | // Fusion | ||||
| SubstitutionPtr mark_interface_fusion_; | SubstitutionPtr mark_interface_fusion_; | ||||
| // IndexedSlices Eliminate | |||||
| SubstitutionPtr indexed_slices_eliminate_; | |||||
| }; | }; | ||||
| // the collection of irpass for resolve action | // the collection of irpass for resolve action | ||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "optimizer/irpass.h" | |||||
| #include "optimizer/optimizer.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "operator/ops.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace irpass { | |||||
| // {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}} | |||||
| // {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}} | |||||
| // {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}} | |||||
| class IndexedSlicesEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| Reset(); | |||||
| AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node); | |||||
| if (is_match_) { | |||||
| return tuple_->input(1); | |||||
| } | |||||
| AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node); | |||||
| if (is_match_) { | |||||
| return tuple_->input(2); | |||||
| } | |||||
| AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node); | |||||
| if (is_match_) { | |||||
| return tuple_->input(3); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void Visit(const CNodePtr &cnode) override { | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) { | |||||
| tuple_ = cnode; | |||||
| is_match_ = true; | |||||
| } | |||||
| } | |||||
| void Reset() { | |||||
| tuple_ = nullptr; | |||||
| is_match_ = false; | |||||
| } | |||||
| private: | |||||
| bool is_match_{false}; | |||||
| CNodePtr tuple_{nullptr}; | |||||
| }; | |||||
| } // namespace irpass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ | |||||
| @@ -232,6 +232,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||||
| auto sparse_grad = | auto sparse_grad = | ||||
| py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad")); | py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad")); | ||||
| ptr->set_sparse_grad(sparse_grad); | ptr->set_sparse_grad(sparse_grad); | ||||
| auto has_indexed_slices_grad = | |||||
| py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad")); | |||||
| ptr->set_has_indexed_slices_grad(has_indexed_slices_grad); | |||||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | ||||
| args_spec.push_back(ptr); | args_spec.push_back(ptr); | ||||
| @@ -154,7 +154,9 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") | .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") | ||||
| .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, | .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, | ||||
| "Set the GraphKernel switch to on or off.") | "Set the GraphKernel switch to on or off.") | ||||
| .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch."); | |||||
| .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") | |||||
| .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.") | |||||
| .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse."); | |||||
| (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") | ||||
| .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") | ||||
| @@ -156,6 +156,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| irpass.make_ref_eliminate_, | irpass.make_ref_eliminate_, | ||||
| irpass.get_ref_param_eliminate_, | irpass.get_ref_param_eliminate_, | ||||
| irpass.indexed_slices_eliminate_, | |||||
| }); | }); | ||||
| OptPassGroupMap map({ | OptPassGroupMap map({ | ||||
| {"b_1", b_1}, | {"b_1", b_1}, | ||||
| @@ -33,148 +33,157 @@ namespace mindspore { | |||||
| namespace pipeline { | namespace pipeline { | ||||
| MethodMap &GetMethodMap() { | MethodMap &GetMethodMap() { | ||||
| static MethodMap method_map = {{kObjectTypeString, | |||||
| { | |||||
| {"__bool__", std::string("str_bool")} // C.str_bool | |||||
| }}, | |||||
| {kMetaTypeNone, | |||||
| { | |||||
| {"__bool__", std::string("none_bool")} // C.none_bool | |||||
| }}, | |||||
| {kNumberTypeBool, | |||||
| { | |||||
| {"__and__", prim::kPrimBoolAnd}, // P.bool_and | |||||
| {"__or__", prim::kPrimBoolOr}, // P.bool_or | |||||
| {"__eq__", prim::kPrimBoolEq}, // P.bool_eq | |||||
| {"__ne__", std::string("bool_ne")}, // C.bool_ne | |||||
| {"__bool__", prim::kPrimIdentity} // P.identity | |||||
| }}, | |||||
| {kNumberTypeInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul | |||||
| {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array | |||||
| }}, | |||||
| {kNumberTypeUInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kNumberTypeFloat, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv | |||||
| {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, | |||||
| {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("float_bool")}, // C.float_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kObjectTypeTuple, | |||||
| { | |||||
| {"__len__", prim::kPrimTupleLen}, // P.tuple_len, | |||||
| {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, | |||||
| {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, | |||||
| {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext | |||||
| {"__bool__", std::string("tuple_bool")} // C.tuple_bool | |||||
| }}, | |||||
| {kObjectTypeList, | |||||
| { | |||||
| {"__len__", prim::kPrimListLen}, // P.list_len, | |||||
| {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, | |||||
| {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity | |||||
| {"__ms_next__", std::string("list_next")}, // C.list_next | |||||
| {"append", std::string("list_append")}, // C.list_next | |||||
| {"__bool__", std::string("list_bool")}, // C.list_bool | |||||
| {"__ms_hasnext__", std::string("list_hasnext")}, | |||||
| }}, | |||||
| {kObjectTypeDictionary, | |||||
| { | |||||
| {"__len__", prim::kPrimDictLen}, // P.dict_len | |||||
| {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem | |||||
| {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, | |||||
| {"__bool__", std::string("dict_bool")} // C.dict_bool | |||||
| }}, | |||||
| {kObjectTypeTensorType, | |||||
| { | |||||
| {"__add__", std::string("add")}, // C.add | |||||
| {"__sub__", std::string("sub")}, // C.sub | |||||
| {"__mul__", std::string("mul")}, // C.mul | |||||
| {"__truediv__", std::string("truediv")}, // C.truediv | |||||
| {"__floordiv__", std::string("floordiv")}, // C.floordiv | |||||
| {"__mod__", std::string("mod")}, // C.mod | |||||
| {"__pow__", std::string("pow_")}, // C.pow | |||||
| {"__floor__", std::string("array_floor")}, // C.array_floor | |||||
| {"__trunc__", std::string("array_trunc")}, // C.array_trunc | |||||
| {"__pos__", std::string("array_uadd")}, // C.array_uadd | |||||
| {"__neg__", std::string("array_usub")}, // C.array_usub | |||||
| {"__eq__", std::string("eq")}, // C.eq | |||||
| {"__ne__", std::string("ne")}, // C.ne | |||||
| {"__lt__", std::string("lt")}, // C.lt | |||||
| {"__gt__", std::string("gt")}, // C.gt | |||||
| {"__le__", std::string("le")}, // C.le | |||||
| {"__ge__", std::string("ge")}, // C.ge | |||||
| {"__matmul__", prim::kPrimDot}, // P.dot, | |||||
| {"__len__", prim::kPrimArrayLen}, // P.array_len, | |||||
| {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, | |||||
| {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, | |||||
| {"__ms_iter__", std::string("array_iter")}, // C.array_iter | |||||
| {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, | |||||
| {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, | |||||
| {"transpose", std::string("transpose")}, // P.transpose | |||||
| {"__bool__", std::string("tensor_bool")}, // C.tensor_bool | |||||
| }}, | |||||
| {kObjectTypeJTagged, {}}, | |||||
| {kObjectTypeSymbolicKeyType, {}}, | |||||
| {kObjectTypeEnvType, {}}}; | |||||
| static MethodMap method_map = { | |||||
| {kObjectTypeString, | |||||
| { | |||||
| {"__bool__", std::string("str_bool")} // C.str_bool | |||||
| }}, | |||||
| {kMetaTypeNone, | |||||
| { | |||||
| {"__bool__", std::string("none_bool")} // C.none_bool | |||||
| }}, | |||||
| {kNumberTypeBool, | |||||
| { | |||||
| {"__and__", prim::kPrimBoolAnd}, // P.bool_and | |||||
| {"__or__", prim::kPrimBoolOr}, // P.bool_or | |||||
| {"__eq__", prim::kPrimBoolEq}, // P.bool_eq | |||||
| {"__ne__", std::string("bool_ne")}, // C.bool_ne | |||||
| {"__bool__", prim::kPrimIdentity} // P.identity | |||||
| }}, | |||||
| {kNumberTypeInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul | |||||
| {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array | |||||
| }}, | |||||
| {kNumberTypeUInt, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__truediv__", std::string("int_truediv")}, // C.int_truediv | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__trunc__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("int_bool")}, // C.int_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kNumberTypeFloat, | |||||
| { | |||||
| {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, | |||||
| {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, | |||||
| {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, | |||||
| {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv | |||||
| {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, | |||||
| {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, | |||||
| {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, | |||||
| {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, | |||||
| {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, | |||||
| {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, | |||||
| {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, | |||||
| {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, | |||||
| {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, | |||||
| {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, | |||||
| {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, | |||||
| {"__le__", prim::kPrimScalarLe}, // P.scalar_le, | |||||
| {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, | |||||
| {"__bool__", std::string("float_bool")}, // C.float_bool | |||||
| {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, | |||||
| }}, | |||||
| {kObjectTypeTuple, | |||||
| { | |||||
| {"__len__", prim::kPrimTupleLen}, // P.tuple_len, | |||||
| {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, | |||||
| {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity, | |||||
| {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, | |||||
| {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext | |||||
| {"__bool__", std::string("tuple_bool")} // C.tuple_bool | |||||
| }}, | |||||
| {kObjectTypeList, | |||||
| { | |||||
| {"__len__", prim::kPrimListLen}, // P.list_len, | |||||
| {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, | |||||
| {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, | |||||
| {"__ms_iter__", prim::kPrimIdentity}, // P.identity | |||||
| {"__ms_next__", std::string("list_next")}, // C.list_next | |||||
| {"append", std::string("list_append")}, // C.list_next | |||||
| {"__bool__", std::string("list_bool")}, // C.list_bool | |||||
| {"__ms_hasnext__", std::string("list_hasnext")}, | |||||
| }}, | |||||
| {kObjectTypeDictionary, | |||||
| { | |||||
| {"__len__", prim::kPrimDictLen}, // P.dict_len | |||||
| {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem | |||||
| {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, | |||||
| {"__bool__", std::string("dict_bool")} // C.dict_bool | |||||
| }}, | |||||
| {kObjectTypeTensorType, | |||||
| { | |||||
| {"__add__", std::string("add")}, // C.add | |||||
| {"__sub__", std::string("sub")}, // C.sub | |||||
| {"__mul__", std::string("mul")}, // C.mul | |||||
| {"__truediv__", std::string("truediv")}, // C.truediv | |||||
| {"__floordiv__", std::string("floordiv")}, // C.floordiv | |||||
| {"__mod__", std::string("mod")}, // C.mod | |||||
| {"__pow__", std::string("pow_")}, // C.pow | |||||
| {"__floor__", std::string("array_floor")}, // C.array_floor | |||||
| {"__trunc__", std::string("array_trunc")}, // C.array_trunc | |||||
| {"__pos__", std::string("array_uadd")}, // C.array_uadd | |||||
| {"__neg__", std::string("array_usub")}, // C.array_usub | |||||
| {"__eq__", std::string("eq")}, // C.eq | |||||
| {"__ne__", std::string("ne")}, // C.ne | |||||
| {"__lt__", std::string("lt")}, // C.lt | |||||
| {"__gt__", std::string("gt")}, // C.gt | |||||
| {"__le__", std::string("le")}, // C.le | |||||
| {"__ge__", std::string("ge")}, // C.ge | |||||
| {"__matmul__", prim::kPrimDot}, // P.dot, | |||||
| {"__len__", prim::kPrimArrayLen}, // P.array_len, | |||||
| {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, | |||||
| {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, | |||||
| {"__ms_iter__", std::string("array_iter")}, // C.array_iter | |||||
| {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, | |||||
| {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, | |||||
| {"transpose", std::string("transpose")}, // P.transpose | |||||
| {"__bool__", std::string("tensor_bool")}, // C.tensor_bool | |||||
| {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices | |||||
| }}, | |||||
| {kObjectTypeIndexedSlicesType, | |||||
| { | |||||
| {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices | |||||
| {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values | |||||
| {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices | |||||
| {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape | |||||
| }}, | |||||
| {kObjectTypeJTagged, {}}, | |||||
| {kObjectTypeSymbolicKeyType, {}}, | |||||
| {kObjectTypeEnvType, {}}}; | |||||
| return method_map; | return method_map; | ||||
| } | } | ||||
| @@ -30,6 +30,10 @@ bool AbstractBase::operator==(const AbstractBase &other) const { | |||||
| if (tid() != other.tid()) { | if (tid() != other.tid()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (BuildType()->type_id() == kObjectTypeUndeterminedType && | |||||
| other.BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||||
| return true; | |||||
| } | |||||
| if (value_ == nullptr || other.value_ == nullptr) { | if (value_ == nullptr || other.value_ == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: " | MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: " | ||||
| << this->ToString() << ", other: " << other.ToString(); | << this->ToString() << ", other: " << other.ToString(); | ||||
| @@ -65,7 +69,7 @@ std::string AbstractBase::ToString() const { | |||||
| MS_EXCEPTION_IF_NULL(shape_); | MS_EXCEPTION_IF_NULL(shape_); | ||||
| buffer << type_name() << "(" | buffer << type_name() << "(" | ||||
| << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() | << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() | ||||
| << " sparse_grad: " << sparse_grad_ << ")"; | |||||
| << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")"; | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| @@ -76,6 +80,7 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | |||||
| if (*this == *other) { | if (*this == *other) { | ||||
| auto ret = shared_from_base<AbstractBase>(); | auto ret = shared_from_base<AbstractBase>(); | ||||
| ret->set_sparse_grad(sparse_grad()); | ret->set_sparse_grad(sparse_grad()); | ||||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| auto value_self = GetValueTrack(); | auto value_self = GetValueTrack(); | ||||
| @@ -85,10 +90,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { | |||||
| if (res_value == value_self) { | if (res_value == value_self) { | ||||
| auto ret = shared_from_base<AbstractBase>(); | auto ret = shared_from_base<AbstractBase>(); | ||||
| ret->set_sparse_grad(sparse_grad()); | ret->set_sparse_grad(sparse_grad()); | ||||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| auto ret = std::make_shared<AbstractScalar>(res_value, res_type); | auto ret = std::make_shared<AbstractScalar>(res_value, res_type); | ||||
| ret->set_sparse_grad(sparse_grad()); | ret->set_sparse_grad(sparse_grad()); | ||||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -409,6 +416,14 @@ std::size_t AbstractSlice::hash() const { | |||||
| return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); | return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); | ||||
| } | } | ||||
| ShapePtr AbstractUndetermined::shape() const { | |||||
| auto shp = dyn_cast<Shape>(GetShapeTrack()); | |||||
| if (shp == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Tensor should have a shape."; | |||||
| } | |||||
| return shp; | |||||
| } | |||||
| TypePtr AbstractTensor::BuildType() const { | TypePtr AbstractTensor::BuildType() const { | ||||
| MS_EXCEPTION_IF_NULL(element_); | MS_EXCEPTION_IF_NULL(element_); | ||||
| TypePtr element_type = element_->BuildType(); | TypePtr element_type = element_->BuildType(); | ||||
| @@ -425,6 +440,13 @@ BaseShapePtr AbstractTensor::BuildShape() const { | |||||
| } | } | ||||
| AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | ||||
| if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||||
| auto other_tensor = dyn_cast<AbstractUndetermined>(other); | |||||
| auto element = element_->Join(other_tensor->element()); | |||||
| auto shape = ShapeJoin(this->shape(), other_tensor->shape()); | |||||
| auto ret = std::make_shared<AbstractUndetermined>(element, shape); | |||||
| return ret; | |||||
| } | |||||
| auto other_tensor = dyn_cast<AbstractTensor>(other); | auto other_tensor = dyn_cast<AbstractTensor>(other); | ||||
| if (other_tensor == nullptr) { | if (other_tensor == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); | MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); | ||||
| @@ -433,6 +455,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { | |||||
| auto shape = ShapeJoin(this->shape(), other_tensor->shape()); | auto shape = ShapeJoin(this->shape(), other_tensor->shape()); | ||||
| auto ret = std::make_shared<AbstractTensor>(element, shape); | auto ret = std::make_shared<AbstractTensor>(element, shape); | ||||
| ret->set_sparse_grad(sparse_grad()); | ret->set_sparse_grad(sparse_grad()); | ||||
| ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -474,6 +497,7 @@ AbstractBasePtr AbstractTensor::Clone() const { | |||||
| clone->set_shape(shp->Clone()); | clone->set_shape(shp->Clone()); | ||||
| clone->set_value(GetValueTrack()); | clone->set_value(GetValueTrack()); | ||||
| clone->set_sparse_grad(sparse_grad()); | clone->set_sparse_grad(sparse_grad()); | ||||
| clone->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||||
| return clone; | return clone; | ||||
| } | } | ||||
| @@ -484,6 +508,7 @@ AbstractBasePtr AbstractTensor::Broaden() const { | |||||
| broaden->set_shape(shp->Clone()); | broaden->set_shape(shp->Clone()); | ||||
| broaden->set_value(kAnyValue); | broaden->set_value(kAnyValue); | ||||
| broaden->set_sparse_grad(sparse_grad()); | broaden->set_sparse_grad(sparse_grad()); | ||||
| broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||||
| return broaden; | return broaden; | ||||
| } | } | ||||
| @@ -495,17 +520,10 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const { | |||||
| broaden->set_shape(shp); | broaden->set_shape(shp); | ||||
| broaden->set_value(kAnyValue); | broaden->set_value(kAnyValue); | ||||
| broaden->set_sparse_grad(sparse_grad()); | broaden->set_sparse_grad(sparse_grad()); | ||||
| broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); | |||||
| return broaden; | return broaden; | ||||
| } | } | ||||
| ShapePtr AbstractTensor::shape() const { | |||||
| auto shp = dyn_cast<Shape>(GetShapeTrack()); | |||||
| if (shp == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Tensor should have a shape."; | |||||
| } | |||||
| return shp; | |||||
| } | |||||
| std::string AbstractTensor::ToString() const { | std::string AbstractTensor::ToString() const { | ||||
| std::ostringstream buffer; | std::ostringstream buffer; | ||||
| BaseShapePtr shape_track = GetShapeTrack(); | BaseShapePtr shape_track = GetShapeTrack(); | ||||
| @@ -516,7 +534,7 @@ std::string AbstractTensor::ToString() const { | |||||
| buffer << type_name() << "(" | buffer << type_name() << "(" | ||||
| << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() | << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() | ||||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() | << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() | ||||
| << ")"; | |||||
| << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")"; | |||||
| return buffer.str(); | return buffer.str(); | ||||
| } | } | ||||
| @@ -1019,5 +1037,64 @@ std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &arg | |||||
| bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { | bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { | ||||
| return AbstractBasePtrListDeepEqual(lhs, rhs); | return AbstractBasePtrListDeepEqual(lhs, rhs); | ||||
| } | } | ||||
| // IndexedSlices | |||||
| TypePtr AbstractIndexedSlices::BuildType() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | |||||
| TypePtr element_type = element()->BuildType(); | |||||
| return std::make_shared<IndexedSlicesType>(element_type); | |||||
| } | |||||
| AbstractBasePtr AbstractIndexedSlices::Clone() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | |||||
| auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone()); | |||||
| ShapePtr shp = shape(); | |||||
| clone->set_shape(shp->Clone()); | |||||
| clone->set_value(GetValueTrack()); | |||||
| clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>()); | |||||
| clone->set_values(values_->Clone()->cast<AbstractTensorPtr>()); | |||||
| clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>()); | |||||
| return clone; | |||||
| } | |||||
| AbstractBasePtr AbstractIndexedSlices::Broaden() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | |||||
| auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden()); | |||||
| auto shp = shape(); | |||||
| broaden->set_shape(shp->Clone()); | |||||
| broaden->set_value(kAnyValue); | |||||
| broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>()); | |||||
| broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>()); | |||||
| broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>()); | |||||
| return broaden; | |||||
| } | |||||
| AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const { | |||||
| MS_EXCEPTION_IF_NULL(element()); | |||||
| auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden()); | |||||
| auto shp = shape()->Clone(); | |||||
| shp->Broaden(); | |||||
| broaden->set_shape(shp); | |||||
| broaden->set_value(kAnyValue); | |||||
| broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>()); | |||||
| broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>()); | |||||
| broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>()); | |||||
| return broaden; | |||||
| } | |||||
| std::string AbstractIndexedSlices::ToString() const { | |||||
| std::ostringstream buffer; | |||||
| BaseShapePtr shape_track = GetShapeTrack(); | |||||
| MS_EXCEPTION_IF_NULL(shape_track); | |||||
| MS_EXCEPTION_IF_NULL(element()); | |||||
| auto value_track = GetValueTrack(); | |||||
| MS_EXCEPTION_IF_NULL(value_track); | |||||
| buffer << type_name() << "(" | |||||
| << "shape: " << shape_track->ToString() << ", element: " << element()->ToString() | |||||
| << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")" | |||||
| << ", indices: " << indices_->ToString() << ", values" << values_->ToString() | |||||
| << ", dense_shape: " << dense_shape_->ToString(); | |||||
| return buffer.str(); | |||||
| } | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,7 +44,7 @@ class AbstractBase : public Base { | |||||
| public: | public: | ||||
| explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, | explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, | ||||
| const BaseShapePtr &shape = kNoShape) | const BaseShapePtr &shape = kNoShape) | ||||
| : value_(value), type_(type), shape_(shape), sparse_grad_("") {} | |||||
| : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {} | |||||
| ~AbstractBase() override = default; | ~AbstractBase() override = default; | ||||
| MS_DECLARE_PARENT(AbstractBase, Base) | MS_DECLARE_PARENT(AbstractBase, Base) | ||||
| @@ -54,12 +54,16 @@ class AbstractBase : public Base { | |||||
| virtual bool operator==(const AbstractBase &other) const; | virtual bool operator==(const AbstractBase &other) const; | ||||
| void set_value(const ValuePtr &value) { value_ = value; } | void set_value(const ValuePtr &value) { value_ = value; } | ||||
| void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } | void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } | ||||
| void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) { | |||||
| has_indexed_slices_grad_ = has_indexed_slices_grad; | |||||
| } | |||||
| void set_type(const TypePtr &type) { type_ = type; } | void set_type(const TypePtr &type) { type_ = type; } | ||||
| void set_shape(const BaseShapePtr &shape) { shape_ = shape; } | void set_shape(const BaseShapePtr &shape) { shape_ = shape; } | ||||
| void set_value_desc(const std::string &desc) { value_desc_ = desc; } | void set_value_desc(const std::string &desc) { value_desc_ = desc; } | ||||
| const std::string &value_desc() const { return value_desc_; } | const std::string &value_desc() const { return value_desc_; } | ||||
| ValuePtr GetValueTrack() const { return value_; } | ValuePtr GetValueTrack() const { return value_; } | ||||
| const std::string &sparse_grad() const { return sparse_grad_; } | const std::string &sparse_grad() const { return sparse_grad_; } | ||||
| const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; } | |||||
| TypePtr GetTypeTrack() const { return type_; } | TypePtr GetTypeTrack() const { return type_; } | ||||
| BaseShapePtr GetShapeTrack() const { return shape_; } | BaseShapePtr GetShapeTrack() const { return shape_; } | ||||
| @@ -88,6 +92,7 @@ class AbstractBase : public Base { | |||||
| BaseShapePtr shape_; | BaseShapePtr shape_; | ||||
| std::string value_desc_; // store initial value description for error report | std::string value_desc_; // store initial value description for error report | ||||
| std::string sparse_grad_; | std::string sparse_grad_; | ||||
| bool has_indexed_slices_grad_; | |||||
| }; | }; | ||||
| class AbstractScalar : public AbstractBase { | class AbstractScalar : public AbstractBase { | ||||
| @@ -231,35 +236,49 @@ class AbstractKeywordArg : public AbstractBase { | |||||
| }; | }; | ||||
| using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>; | using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>; | ||||
| class AbstractTensor : public AbstractBase { | |||||
| class AbstractUndetermined : public AbstractBase { | |||||
| public: | public: | ||||
| // shape and type are all unknown | |||||
| AbstractUndetermined() : AbstractBase(kAnyValue) {} | |||||
| // only element_ and value, shape track are valid member, type track are unknown. | // only element_ and value, shape track are valid member, type track are unknown. | ||||
| explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) | |||||
| explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) | |||||
| : AbstractBase(kAnyValue), element_(element) { | : AbstractBase(kAnyValue), element_(element) { | ||||
| if (element == nullptr) { | if (element == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "element is nullptr"; | MS_LOG(EXCEPTION) << "element is nullptr"; | ||||
| } | } | ||||
| if (element->isa<AbstractTensor>()) { | |||||
| if (element->isa<AbstractUndetermined>()) { | |||||
| MS_LOG(EXCEPTION) << "element type error"; | MS_LOG(EXCEPTION) << "element type error"; | ||||
| } | } | ||||
| set_shape(shape); | set_shape(shape); | ||||
| } | } | ||||
| AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape) | |||||
| AbstractUndetermined(const TypePtr &element_type, const std::vector<int> &shape) | |||||
| : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) { | : AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) { | ||||
| if (element_type == nullptr) { | if (element_type == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "element_type is nullptr"; | MS_LOG(EXCEPTION) << "element_type is nullptr"; | ||||
| } | } | ||||
| set_shape(std::make_shared<Shape>(shape)); | set_shape(std::make_shared<Shape>(shape)); | ||||
| } | } | ||||
| explicit AbstractTensor(const tensor::TensorPtr &tensor) | |||||
| : AbstractBase(tensor), element_(std::make_shared<AbstractScalar>(kAnyValue, tensor->Dtype())) { | |||||
| if (tensor == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "tensor is nullptr"; | |||||
| } | |||||
| set_shape(std::make_shared<Shape>(tensor->shape())); | |||||
| } | |||||
| ~AbstractUndetermined() override = default; | |||||
| MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase) | |||||
| TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); } | |||||
| AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); } | |||||
| const AbstractBasePtr element() const { return element_; } | |||||
| ShapePtr shape() const; | |||||
| protected: | |||||
| AbstractBasePtr element_; | |||||
| }; | |||||
| class AbstractTensor : public AbstractUndetermined { | |||||
| public: | |||||
| // only element_ and value, shape track are valid member, type track are unknown. | |||||
| explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) | |||||
| : AbstractUndetermined(element, shape) {} | |||||
| AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape) | |||||
| : AbstractUndetermined(element_type, shape) {} | |||||
| explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {} | |||||
| ~AbstractTensor() override = default; | ~AbstractTensor() override = default; | ||||
| MS_DECLARE_PARENT(AbstractTensor, AbstractBase) | |||||
| MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined) | |||||
| TypePtr BuildType() const override; | TypePtr BuildType() const override; | ||||
| BaseShapePtr BuildShape() const override; | BaseShapePtr BuildShape() const override; | ||||
| @@ -271,9 +290,7 @@ class AbstractTensor : public AbstractBase { | |||||
| bool operator==(const AbstractTensor &other) const; | bool operator==(const AbstractTensor &other) const; | ||||
| bool operator==(const AbstractBase &other) const override; | bool operator==(const AbstractBase &other) const override; | ||||
| ShapePtr shape() const; | |||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| const AbstractBasePtr element() const { return element_; } | |||||
| std::size_t hash() const override { | std::size_t hash() const override { | ||||
| auto value = GetValueTrack(); | auto value = GetValueTrack(); | ||||
| auto hash_sum = hash_combine(tid(), element_->hash()); | auto hash_sum = hash_combine(tid(), element_->hash()); | ||||
| @@ -285,9 +302,6 @@ class AbstractTensor : public AbstractBase { | |||||
| } | } | ||||
| return hash_sum; | return hash_sum; | ||||
| } | } | ||||
| private: | |||||
| AbstractBasePtr element_; | |||||
| }; | }; | ||||
| using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; | using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; | ||||
| using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; | using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; | ||||
| @@ -585,6 +599,35 @@ struct AbstractBasePtrListEqual { | |||||
| std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); | std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); | ||||
| bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); | bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); | ||||
| // IndexedSlices | |||||
| class AbstractIndexedSlices : public AbstractUndetermined { | |||||
| public: | |||||
| explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>()) | |||||
| : AbstractUndetermined(element, shape) {} | |||||
| AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape) | |||||
| : AbstractUndetermined(element_type, shape) {} | |||||
| ~AbstractIndexedSlices() override = default; | |||||
| MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined) | |||||
| const AbstractTensorPtr indices() const { return indices_; } | |||||
| const AbstractTensorPtr values() const { return values_; } | |||||
| const AbstractTuplePtr dense_shape() const { return dense_shape_; } | |||||
| void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } | |||||
| void set_values(const AbstractTensorPtr &values) { values_ = values; } | |||||
| void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } | |||||
| TypePtr BuildType() const override; | |||||
| AbstractBasePtr Clone() const override; | |||||
| AbstractBasePtr Broaden() const override; | |||||
| AbstractBasePtr BroadenWithShape() const; | |||||
| std::string ToString() const override; | |||||
| private: | |||||
| AbstractTensorPtr indices_; | |||||
| AbstractTensorPtr values_; | |||||
| AbstractTuplePtr dense_shape_; | |||||
| }; | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_ | #endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_ | ||||
| @@ -58,6 +58,20 @@ class Evaluator : public Base { | |||||
| return args_spec_list; | return args_spec_list; | ||||
| } | } | ||||
| virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { | |||||
| auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { | |||||
| if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| }); | |||||
| if (is_abstract) { | |||||
| MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result"; | |||||
| return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>()); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| std::string ToString() const override { return identifier_; } | std::string ToString() const override { return identifier_; } | ||||
| virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } | virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } | ||||
| @@ -66,6 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function) | |||||
| ABSTRACT_REPORT_NAME_TRAITS(Type) | ABSTRACT_REPORT_NAME_TRAITS(Type) | ||||
| ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) | ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) | ||||
| ABSTRACT_REPORT_NAME_TRAITS(Class) | ABSTRACT_REPORT_NAME_TRAITS(Class) | ||||
| ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices) | |||||
| template <typename T> | template <typename T> | ||||
| std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) { | std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) { | ||||
| @@ -36,6 +36,7 @@ | |||||
| #include "pipeline/parse/resolve.h" | #include "pipeline/parse/resolve.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/context/ms_context.h" | |||||
| #include "pipeline/parse/data_converter.h" | #include "pipeline/parse/data_converter.h" | ||||
| #include "pipeline/static_analysis/param_validator.h" | #include "pipeline/static_analysis/param_validator.h" | ||||
| #include "common/utils.h" | #include "common/utils.h" | ||||
| @@ -132,6 +133,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | {prim::kPrimControlDepend, {InferImplControlDepend, true}}, | ||||
| // Debug | // Debug | ||||
| {prim::kPrimDebug, {InferImplDebug, true}}, | {prim::kPrimDebug, {InferImplDebug, true}}, | ||||
| // IndexedSlices | |||||
| {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}}, | |||||
| {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, | |||||
| {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, | |||||
| {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, | |||||
| {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}}, | |||||
| }; | }; | ||||
| return prim_eval_implement_map; | return prim_eval_implement_map; | ||||
| } | } | ||||
| @@ -139,6 +146,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| using mindspore::parse::PyObjectWrapper; | using mindspore::parse::PyObjectWrapper; | ||||
| EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { | ||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||||
| if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { | |||||
| auto ret_abstract = AbstractEval(args); | |||||
| if (ret_abstract != nullptr) { | |||||
| MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; | |||||
| return ret_abstract; | |||||
| } | |||||
| } | |||||
| prim_->BeginRecordAddAttr(); | prim_->BeginRecordAddAttr(); | ||||
| AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); | ||||
| prim_->EndRecordAddAttr(); | prim_->EndRecordAddAttr(); | ||||
| @@ -485,6 +502,16 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic | |||||
| } // end anonymous namespace | } // end anonymous namespace | ||||
| EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | ||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||||
| if (enable_sparse_flag) { | |||||
| auto ret_abstract = AbstractEval(args); | |||||
| if (ret_abstract != nullptr) { | |||||
| MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; | |||||
| return ret_abstract; | |||||
| } | |||||
| } | |||||
| MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); | MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); | ||||
| const auto &iter = cache_->find(args); | const auto &iter = cache_->find(args); | ||||
| @@ -512,6 +539,16 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs | |||||
| } | } | ||||
| EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { | ||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||||
| if (enable_sparse_flag) { | |||||
| auto ret_abstract = AbstractEval(args); | |||||
| if (ret_abstract != nullptr) { | |||||
| MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined"; | |||||
| return ret_abstract; | |||||
| } | |||||
| } | |||||
| // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. | // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. | ||||
| if (nargs_ != args.size()) { | if (nargs_ != args.size()) { | ||||
| MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; | MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; | ||||
| @@ -871,6 +908,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| auto ref_value = ref_abs->ref(); | auto ref_value = ref_abs->ref(); | ||||
| MS_EXCEPTION_IF_NULL(ref_value); | MS_EXCEPTION_IF_NULL(ref_value); | ||||
| ret->set_sparse_grad(ref_value->sparse_grad()); | ret->set_sparse_grad(ref_value->sparse_grad()); | ||||
| ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad()); | |||||
| return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); | ||||
| } | } | ||||
| @@ -886,6 +924,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { | |||||
| std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); | ||||
| std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); | ||||
| abs_scalar->set_sparse_grad(x->sparse_grad()); | abs_scalar->set_sparse_grad(x->sparse_grad()); | ||||
| abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad()); | |||||
| return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>()); | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -897,6 +936,16 @@ class GetAttrEvaluator : public TransitionPrimEvaluator { | |||||
| MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); | MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); | ||||
| EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, | ||||
| const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { | ||||
| auto context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool enable_sparse_flag = context->enable_sparse_flag(); | |||||
| if (enable_sparse_flag) { | |||||
| auto ret_abstract = AbstractEval(args_spec_list); | |||||
| if (ret_abstract != nullptr) { | |||||
| MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; | |||||
| return ret_abstract; | |||||
| } | |||||
| } | |||||
| // Inputs: data, item | // Inputs: data, item | ||||
| if (args_spec_list.size() != 2) { | if (args_spec_list.size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); | ||||
| @@ -350,6 +350,17 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv | |||||
| AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| void InitUndeterminedFromEnv(const std::string &sparse_shape_types); | void InitUndeterminedFromEnv(const std::string &sparse_shape_types); | ||||
| AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| } // namespace abstract | } // namespace abstract | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -228,6 +228,10 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf | |||||
| MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() | MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() | ||||
| << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); | << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); | ||||
| } | } | ||||
| if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { | |||||
| MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; | |||||
| return std::make_shared<EvalResult>(maybe_func->Clone(), std::make_shared<AttrValueMap>()); | |||||
| } | |||||
| AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func); | AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func); | ||||
| if (func == nullptr) { | if (func == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() | MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() | ||||
| @@ -32,6 +32,7 @@ using mindspore::abstract::AbstractBase; | |||||
| using mindspore::abstract::AbstractClass; | using mindspore::abstract::AbstractClass; | ||||
| using mindspore::abstract::AbstractError; | using mindspore::abstract::AbstractError; | ||||
| using mindspore::abstract::AbstractFunction; | using mindspore::abstract::AbstractFunction; | ||||
| using mindspore::abstract::AbstractIndexedSlices; | |||||
| using mindspore::abstract::AbstractJTagged; | using mindspore::abstract::AbstractJTagged; | ||||
| using mindspore::abstract::AbstractList; | using mindspore::abstract::AbstractList; | ||||
| using mindspore::abstract::AbstractScalar; | using mindspore::abstract::AbstractScalar; | ||||
| @@ -93,7 +94,8 @@ void ValidateAbstract(const AnfNodePtr &node) { | |||||
| } | } | ||||
| if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() || | if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() || | ||||
| ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) { | |||||
| ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() || | |||||
| ptrBase->isa<abstract::AbstractRefKey>()) { | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -89,6 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||||
| max_device_memory_ = kDefaultMaxDeviceMemory; | max_device_memory_ = kDefaultMaxDeviceMemory; | ||||
| print_file_path_ = ""; | print_file_path_ = ""; | ||||
| enable_graph_kernel_ = false; | enable_graph_kernel_ = false; | ||||
| enable_sparse_flag_ = false; | |||||
| } | } | ||||
| std::shared_ptr<MsContext> MsContext::GetInstance() { | std::shared_ptr<MsContext> MsContext::GetInstance() { | ||||
| @@ -161,6 +161,9 @@ class MsContext { | |||||
| void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } | void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } | ||||
| bool enable_graph_kernel() const { return enable_graph_kernel_; } | bool enable_graph_kernel() const { return enable_graph_kernel_; } | ||||
| bool enable_sparse_flag() const { return enable_sparse_flag_; } | |||||
| void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; } | |||||
| private: | private: | ||||
| MsContext(const std::string &backend_policy, const std::string &target); | MsContext(const std::string &backend_policy, const std::string &target); | ||||
| void GetGeOptions(std::map<std::string, std::string> *ge_options) const; | void GetGeOptions(std::map<std::string, std::string> *ge_options) const; | ||||
| @@ -204,6 +207,7 @@ class MsContext { | |||||
| float max_device_memory_; | float max_device_memory_; | ||||
| std::string print_file_path_; | std::string print_file_path_; | ||||
| bool enable_graph_kernel_; | bool enable_graph_kernel_; | ||||
| bool enable_sparse_flag_; | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,10 +17,10 @@ from . import dtype | |||||
| from .api import ms_function | from .api import ms_function | ||||
| from .dtype import * | from .dtype import * | ||||
| from .parameter import Parameter, ParameterTuple | from .parameter import Parameter, ParameterTuple | ||||
| from .tensor import MetaTensor, Tensor | |||||
| from .tensor import MetaTensor, Tensor, IndexedSlices | |||||
| __all__ = [ | __all__ = [ | ||||
| "MetaTensor", "Tensor", # tensor | |||||
| "MetaTensor", "Tensor", "IndexedSlices", # tensor | |||||
| 'ms_function', # api | 'ms_function', # api | ||||
| 'Parameter', 'ParameterTuple', # parameter | 'Parameter', 'ParameterTuple', # parameter | ||||
| "dtype" | "dtype" | ||||
| @@ -52,13 +52,16 @@ class Parameter: | |||||
| layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, | layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, | ||||
| broadcast and gradients communication would not be applied on parameters. Default: False. | broadcast and gradients communication would not be applied on parameters. Default: False. | ||||
| sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty. | sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty. | ||||
| has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false. | |||||
| """ | """ | ||||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=""): | |||||
| def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, | |||||
| sparse_grad="", has_indexed_slices_grad=False): | |||||
| self.set_parameter_data(default_input) | self.set_parameter_data(default_input) | ||||
| self.name = name | self.name = name | ||||
| self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
| self.layerwise_parallel = layerwise_parallel | self.layerwise_parallel = layerwise_parallel | ||||
| self.sparse_grad = sparse_grad | self.sparse_grad = sparse_grad | ||||
| self.has_indexed_slices_grad = has_indexed_slices_grad | |||||
| self._is_init = False | self._is_init = False | ||||
| self._sliced = False | self._sliced = False | ||||
| self.clone_info = _CloneInfo() | self.clone_info = _CloneInfo() | ||||
| @@ -186,6 +189,17 @@ class Parameter: | |||||
| raise TypeError("`sparse_grad` parameter must be str type") | raise TypeError("`sparse_grad` parameter must be str type") | ||||
| self._sparse_grad = value | self._sparse_grad = value | ||||
| @property | |||||
| def has_indexed_slices_grad(self): | |||||
| """Return whether the parameter's gradient is indexed_slices.""" | |||||
| return self._has_indexed_slices_grad | |||||
| @has_indexed_slices_grad.setter | |||||
| def has_indexed_slices_grad(self, value=False): | |||||
| if not isinstance(value, bool): | |||||
| raise TypeError("`has_indexed_slices_grad` parameter must be bool type") | |||||
| self._has_indexed_slices_grad = value | |||||
| @property | @property | ||||
| def data(self): | def data(self): | ||||
| return self.default_input | return self.default_input | ||||
| @@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename | |||||
| from . import dtype as mstype | from . import dtype as mstype | ||||
| from ._register_for_tensor import tensor_operator_registry | from ._register_for_tensor import tensor_operator_registry | ||||
| __all__ = ['Tensor', 'MetaTensor'] | |||||
| __all__ = ['Tensor', 'MetaTensor', 'IndexedSlices'] | |||||
| np_types = (np.int8, np.int16, np.int32, np.int64, | np_types = (np.int8, np.int16, np.int32, np.int64, | ||||
| np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | ||||
| np.float32, np.float64, np.bool_) | np.float32, np.float64, np.bool_) | ||||
| @@ -214,3 +214,8 @@ class Tensor(Tensor_): | |||||
| raise TypeError("init_flag must be bool.") | raise TypeError("init_flag must be bool.") | ||||
| self.set_init_flag(value) | self.set_init_flag(value) | ||||
| self._init_flag = value | self._init_flag = value | ||||
| class IndexedSlices: | |||||
| def __init__(self, indices, values, dense_shape): | |||||
| raise NotImplementedError | |||||
| @@ -355,6 +355,14 @@ class _Context: | |||||
| def check_bprop(self, check_bprop_flag): | def check_bprop(self, check_bprop_flag): | ||||
| self._context_handle.set_check_bprop_flag(check_bprop_flag) | self._context_handle.set_check_bprop_flag(check_bprop_flag) | ||||
| @property | |||||
| def enable_sparse(self): | |||||
| return self._context_handle.get_enable_sparse_flag() | |||||
| @enable_sparse.setter | |||||
| def enable_sparse(self, enable_sparse_flag): | |||||
| self._context_handle.set_enable_sparse_flag(enable_sparse_flag) | |||||
| @property | @property | ||||
| def max_device_memory(self): | def max_device_memory(self): | ||||
| return self._context_handle.get_max_device_memory() | return self._context_handle.get_max_device_memory() | ||||
| @@ -510,7 +518,8 @@ def reset_auto_parallel_context(): | |||||
| save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, | save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, | ||||
| save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | ||||
| enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | ||||
| enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str) | |||||
| enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, | |||||
| enable_sparse=bool) | |||||
| def set_context(**kwargs): | def set_context(**kwargs): | ||||
| """ | """ | ||||
| Sets context for running environment. | Sets context for running environment. | ||||
| @@ -567,6 +576,7 @@ def set_context(**kwargs): | |||||
| The format is "xxGB". Default: "1024GB". | The format is "xxGB". Default: "1024GB". | ||||
| print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to | print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to | ||||
| a file by default, and turn off printing to the screen. | a file by default, and turn off printing to the screen. | ||||
| enable_sparse (bool): Whether to enable sparse feature. Default: False. | |||||
| Raises: | Raises: | ||||
| ValueError: If input key is not an attribute in context. | ValueError: If input key is not an attribute in context. | ||||
| @@ -153,6 +153,14 @@ shape_mul = Primitive("shape_mul") | |||||
| # a primitive to compare between tuple. | # a primitive to compare between tuple. | ||||
| stop_gradient = Primitive("stop_gradient") | stop_gradient = Primitive("stop_gradient") | ||||
| make_indexed_slices = Primitive('MakeIndexedSlices') | |||||
| indexed_slices_get_values = Primitive('IndexedSlicesGetValues') | |||||
| indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') | |||||
| indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') | |||||
| is_indexed_slices = Primitive('IsIndexedSlices') | |||||
| tensor_operator_registry.register('__add__', tensor_add) | tensor_operator_registry.register('__add__', tensor_add) | ||||
| tensor_operator_registry.register('__sub__', tensor_sub) | tensor_operator_registry.register('__sub__', tensor_sub) | ||||
| tensor_operator_registry.register('__mul__', tensor_mul) | tensor_operator_registry.register('__mul__', tensor_mul) | ||||
| @@ -564,7 +564,7 @@ class SparseGatherV2(GatherV2): | |||||
| >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32) | >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32) | ||||
| >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32) | >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32) | ||||
| >>> axis = 1 | >>> axis = 1 | ||||
| >>> out = P.GatherV2()(input_params, input_indices, axis) | |||||
| >>> out = P.SparseGatherV2()(input_params, input_indices, axis) | |||||
| """ | """ | ||||
| @@ -603,5 +603,18 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { | |||||
| ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); | ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | ||||
| } | } | ||||
| TEST_F(TestOptLib, test_indexed_slices) { | |||||
| FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices"); | |||||
| FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices"); | |||||
| FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values"); | |||||
| FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values"); | |||||
| FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape"); | |||||
| FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape"); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_}); | |||||
| ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1130,3 +1130,38 @@ def test_adjust_allreduce_mul_add(tag): | |||||
| return Mul(AllReduce(AddN((Mul(z, z), x))), y) | return Mul(AllReduce(AddN((Mul(z, z), x))), y) | ||||
| return fns[tag] | return fns[tag] | ||||
| def test_indexed_slices(tag): | |||||
| """ test_add_zero """ | |||||
| fns = FnDict() | |||||
| make_indexed_slices = Primitive('MakeIndexedSlices') | |||||
| indexed_slices_get_values = Primitive('IndexedSlicesGetValues') | |||||
| indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') | |||||
| indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') | |||||
| @fns | |||||
| def before_get_indices(x, y, z): | |||||
| return indexed_slices_get_indices(make_indexed_slices(x, y, z)) | |||||
| @fns | |||||
| def after_get_indices(x, y, z): | |||||
| return x | |||||
| @fns | |||||
| def before_get_values(x, y, z): | |||||
| return indexed_slices_get_values(make_indexed_slices(x, y, z)) | |||||
| @fns | |||||
| def after_get_values(x, y, z): | |||||
| return y | |||||
| @fns | |||||
| def before_get_dense_shape(x, y, z): | |||||
| return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z)) | |||||
| @fns | |||||
| def after_get_dense_shape(x, y, z): | |||||
| return z | |||||
| return fns[tag] | |||||
| @@ -0,0 +1,290 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| @File : test_indexed_slices.py | |||||
| @Author: | |||||
| @Date : 2020-06-08 | |||||
| @Desc : test mindspore indexed_slices's operation | |||||
| """ | |||||
| import numpy as np | |||||
| import mindspore as ms | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like | |||||
| from mindspore.ops.primitive import constexpr | |||||
| from mindspore.ops._grad.grad_base import bprop_getters | |||||
| from mindspore import Tensor, IndexedSlices, context | |||||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | |||||
| from mindspore.nn import Optimizer | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| reduce_sum = P.ReduceSum() | |||||
| unsorted_segment_sum = P.UnsortedSegmentSum() | |||||
| transpose = P.Transpose() | |||||
| shape_op = P.Shape() | |||||
| reshape = P.Reshape() | |||||
| size_op = P.Size() | |||||
| invert_permutation = P.InvertPermutation() | |||||
| logical_and = P.LogicalAnd() | |||||
| context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) | |||||
| @constexpr | |||||
| def _generate_shape_index(out_shape, indices_shape, axis): | |||||
| out_rank = len(out_shape) | |||||
| ind_rank = len(indices_shape) | |||||
| if axis < 0: | |||||
| axis += out_rank - ind_rank + 1 | |||||
| perm_part1 = tuple(range(axis, axis + ind_rank)) | |||||
| index = tuple(range(out_rank)) | |||||
| perm = perm_part1 + index[:axis] + index[axis + ind_rank:] | |||||
| return perm | |||||
| @constexpr | |||||
| def _generate_inverse_index(x_shape, axis): | |||||
| x_rank = len(x_shape) | |||||
| index = tuple(range(x_rank)) | |||||
| if axis < 0: | |||||
| axis += x_rank | |||||
| perm = index[1:1 + axis] + (0,) + index[1 + axis:] | |||||
| return perm | |||||
| class MySparseGatherV2(P.GatherV2): | |||||
| """ | |||||
| For test | |||||
| """ | |||||
| @bprop_getters.register(MySparseGatherV2) | |||||
| def get_bprop_sparse_gather_v2(self): | |||||
| """Generate bprop for MySparseGatherV2""" | |||||
| def bprop(x, indices, axis, out, dout): | |||||
| x_shp = shape_op(x) | |||||
| if axis == 0: | |||||
| indices_size = (size_op(indices),) | |||||
| x_tail_shp = x_shp[1:] | |||||
| values_shape = indices_size + x_tail_shp | |||||
| values = reshape(dout, values_shape) | |||||
| indices = reshape(indices, indices_size) | |||||
| return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis) | |||||
| if F.rank(dout) == 0: | |||||
| dout = P.ExpandDims()(dout, -1) | |||||
| if F.rank(indices) == 0: | |||||
| indices = P.ExpandDims()(indices, -1) | |||||
| out_shp = shape_op(dout) | |||||
| ind_shp = shape_op(indices) | |||||
| # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) | |||||
| perm_1 = _generate_shape_index(out_shp, ind_shp, axis) | |||||
| values_transpose = transpose(dout, perm_1) | |||||
| params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) | |||||
| # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) | |||||
| perm_2 = _generate_inverse_index(x_shp, axis) | |||||
| params_grad = transpose(params_grad, perm_2) | |||||
| return params_grad, zeros_like(indices), zeros_like(axis) | |||||
| return bprop | |||||
| adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") | |||||
| @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||||
| "Tensor", "Tensor", "Tensor", "Undetermined", "Bool") | |||||
| def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): | |||||
| if gradient.is_indexed_slices(): | |||||
| return gradient.values() | |||||
| op_mul = P.Mul() | |||||
| op_square = P.Square() | |||||
| op_sqrt = P.Sqrt() | |||||
| op_cast = P.Cast() | |||||
| op_reshape = P.Reshape() | |||||
| op_shape = P.Shape() | |||||
| param_fp32 = op_cast(param, mstype.float32) | |||||
| m_fp32 = op_cast(m, mstype.float32) | |||||
| v_fp32 = op_cast(v, mstype.float32) | |||||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||||
| next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) | |||||
| next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) | |||||
| - beta2, op_square(gradient_fp32)) | |||||
| update = next_m / (op_sqrt(next_v) + eps) | |||||
| if decay_flag: | |||||
| update = update + op_mul(weight_decay_tensor, param_fp32) | |||||
| update_with_lr = op_mul(lr, update) | |||||
| next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) | |||||
| next_v = F.depend(next_v, F.assign(param, next_param)) | |||||
| next_v = F.depend(next_v, F.assign(m, next_m)) | |||||
| next_v = F.depend(next_v, F.assign(v, next_v)) | |||||
| return next_v | |||||
| def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||||
| """Check the type of inputs.""" | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||||
| validator.check_value_type("eps", eps, [float], prim_name) | |||||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | |||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| class AdamWeightDecaySparse(Optimizer): | |||||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, | |||||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||||
| super(AdamWeightDecaySparse, self).__init__(learning_rate, params) | |||||
| if self.is_group: | |||||
| raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") | |||||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||||
| self.eps = Tensor(np.array([eps]).astype(np.float32)) | |||||
| self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) | |||||
| self.params = self.parameters | |||||
| self.moments1 = self.params.clone(prefix="adam_m", init='zeros') | |||||
| self.moments2 = self.params.clone(prefix="adam_v", init='zeros') | |||||
| self.decay_flag = tuple(decay_filter(x) for x in self.params) | |||||
| self.map = C.Map() | |||||
| def construct(self, gradients): | |||||
| lr = self.get_lr() | |||||
| updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, | |||||
| self.weight_decay_tensor), | |||||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||||
| return updated_velocity | |||||
| def test_indexed_slices_make_indexed_slices(): | |||||
| class MakeIndexedSlices(nn.Cell): | |||||
| def __init__(self): | |||||
| super(MakeIndexedSlices, self).__init__() | |||||
| self.dense_shape = (3, 4) | |||||
| def construct(self, indices, values): | |||||
| ret = (IndexedSlices(indices, values, self.dense_shape),) | |||||
| return ret[0].is_indexed_slices() | |||||
| indices = Tensor([[0, 0], [1, 2]]) | |||||
| values = Tensor([1, 2], dtype=ms.float32) | |||||
| MakeIndexedSlices()(indices, values) | |||||
| def test_indexed_slices_attr(): | |||||
| class IndexedSlicesGetAttr(nn.Cell): | |||||
| def __init__(self): | |||||
| super(IndexedSlicesGetAttr, self).__init__() | |||||
| self.dense_shape = (3, 4) | |||||
| def construct(self, indices, values): | |||||
| x = IndexedSlices(indices, values, self.dense_shape) | |||||
| return x.values(), x.indices(), x.dense_shape() | |||||
| indices = Tensor([[0, 0], [1, 2]]) | |||||
| values = Tensor([1, 2], dtype=ms.float32) | |||||
| IndexedSlicesGetAttr()(indices, values) | |||||
| def test_indexed_slices_sparse_gatherv2_grad_all(): | |||||
| grad_all = C.GradOperation('get_all', get_all=True) | |||||
| class GradWrap(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(GradWrap, self).__init__() | |||||
| self.network = network | |||||
| def construct(self, x, y): | |||||
| grad = grad_all(self.network)(x, y) | |||||
| return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices() | |||||
| class SparseGatherV2(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SparseGatherV2, self).__init__() | |||||
| self.sparse_gatherv2 = MySparseGatherV2() | |||||
| self.axis = 0 | |||||
| def construct(self, params, indices): | |||||
| return self.sparse_gatherv2(params, indices, self.axis) | |||||
| params = Tensor(np.ones([3, 1, 2]).astype(np.int32)) | |||||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||||
| GradWrap(SparseGatherV2())(params, indices) | |||||
| def test_indexed_slices_sparse_gatherv2_grad_with_pram(): | |||||
| grad_by_list = C.GradOperation('get_by_list', get_by_list=True) | |||||
| class GradWrap(nn.Cell): | |||||
| def __init__(self, network): | |||||
| super(GradWrap, self).__init__() | |||||
| self.network = network | |||||
| self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) | |||||
| def construct(self, x): | |||||
| weights = self.weights | |||||
| grad = grad_by_list(self.network, weights)(x) | |||||
| x = grad[0] | |||||
| return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape() | |||||
| class SparseGatherV2(nn.Cell): | |||||
| def __init__(self): | |||||
| super(SparseGatherV2, self).__init__() | |||||
| self.sparse_gatherv2 = MySparseGatherV2() | |||||
| self.axis = 0 | |||||
| self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), | |||||
| name="params", has_indexed_slices_grad=True) | |||||
| def construct(self, indices): | |||||
| return self.sparse_gatherv2(self.params, indices, self.axis) | |||||
| indices = Tensor(np.array([0, 1]).astype(np.int32)) | |||||
| network = GradWrap(SparseGatherV2()) | |||||
| network(indices) | |||||
| def test_indexed_slices_is_indexed_slices(): | |||||
| class MakeIndexedSlices(nn.Cell): | |||||
| def __init__(self): | |||||
| super(MakeIndexedSlices, self).__init__() | |||||
| self.dense_shape = (3, 4) | |||||
| def construct(self, indices, values): | |||||
| indexed_slices = IndexedSlices(indices, values, self.dense_shape) | |||||
| ret = indexed_slices.is_indexed_slices() | |||||
| return ret | |||||
| indices = Tensor([[0, 0], [1, 2]]) | |||||
| values = Tensor([1, 2], dtype=ms.float32) | |||||
| MakeIndexedSlices()(indices, values) | |||||
| def test_indexed_slices_env_get(): | |||||
| class Loss(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Loss, self).__init__() | |||||
| def construct(self, base, target): | |||||
| return base | |||||
| class NetWithSparseGatherV2(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetWithSparseGatherV2, self).__init__() | |||||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True) | |||||
| self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") | |||||
| self.gatherv2 = MySparseGatherV2() | |||||
| self.axis = 0 | |||||
| def construct(self, indices): | |||||
| return self.gatherv2(self.w1, indices, self.axis) * self.w2 | |||||
| inputs = Tensor(np.array([0, 1]).astype(np.int32)) | |||||
| label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) | |||||
| net = NetWithSparseGatherV2() | |||||
| net.set_train() | |||||
| loss = Loss() | |||||
| optimizer = AdamWeightDecaySparse(net.trainable_params()) | |||||
| net_with_loss = WithLossCell(net, loss) | |||||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||||
| train_network(inputs, label) | |||||
| @@ -155,7 +155,7 @@ def test_AdamWeightDecaySparse(): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(NetWithSparseGatherV2, self).__init__() | super(NetWithSparseGatherV2, self).__init__() | ||||
| self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1") | self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1") | ||||
| self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2", sparse_grad="sparse_key_w2") | |||||
| self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") | |||||
| self.gatherv2 = P.SparseGatherV2() | self.gatherv2 = P.SparseGatherV2() | ||||
| self.axis = 0 | self.axis = 0 | ||||
| def construct(self, indices): | def construct(self, indices): | ||||