Merge pull request !1189 from leopz/mastertags/v0.3.0-alpha
| @@ -26,6 +26,7 @@ | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/symbolic.h" | |||
| #include "ir/meta_func_graph.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| #include "pipeline/parse/resolve.h" | |||
| #include "operator/composite/composite.h" | |||
| @@ -469,7 +470,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNode | |||
| MS_LOG(EXCEPTION) << "Param could not cast to parameter"; | |||
| } | |||
| if (param_ptr->has_default()) { | |||
| ofs << " = @" << DumpObject(param_ptr->default_param(), "D"); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param()); | |||
| ofs << " = @" << DumpObject(param_value->value(), "D"); | |||
| } | |||
| // output comment | |||
| @@ -1650,7 +1652,8 @@ class IrParser { | |||
| // load parameter default value from serialized file | |||
| py::object default_obj = LoadObject(lexer_.GetTokenText()); | |||
| param->set_default_param(default_obj); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(default_obj); | |||
| param->set_default_param(param_value_new); | |||
| tok = lexer_.GetNextToken(); | |||
| } | |||
| @@ -21,12 +21,17 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/meta_func_graph.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/utils.h" | |||
| #include "operator/composite/composite.h" | |||
| #include "ir/meta_tensor.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| // namespace to support debug utils | |||
| @@ -312,17 +317,21 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { | |||
| for (auto ¶meter : key->parameters()) { | |||
| buffer_ << "<tr><td>"; | |||
| buffer_ << parameter->ToString(); | |||
| auto py_p = dyn_cast<Parameter>(parameter)->default_param(); | |||
| if (py::hasattr(py_p, "default_input")) { | |||
| py_p = py_p.attr("default_input"); | |||
| if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) { | |||
| std::shared_ptr<tensor::Tensor> m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>(); | |||
| py::tuple shape = m_tensor->GetPyTupleShape(); | |||
| buffer_ << "[" << std::string(py::str(shape)) << "]"; | |||
| } else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) { | |||
| std::shared_ptr<tensor::MetaTensor> m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>(); | |||
| py::tuple shape = m_tensor->GetPyTupleShape(); | |||
| buffer_ << "[" << std::string(py::str(shape)) << "]"; | |||
| auto param = parameter->cast<ParameterPtr>(); | |||
| if (param->has_default()) { | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param()); | |||
| auto py_p = param_value->value(); | |||
| if (py::hasattr(py_p, "default_input")) { | |||
| py_p = py_p.attr("default_input"); | |||
| if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) { | |||
| auto m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>(); | |||
| py::tuple shape = m_tensor->GetPyTupleShape(); | |||
| buffer_ << "[" << py::str(shape) << "]"; | |||
| } else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) { | |||
| auto m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>(); | |||
| py::tuple shape = m_tensor->GetPyTupleShape(); | |||
| buffer_ << "[" << py::str(shape) << "]"; | |||
| } | |||
| } | |||
| } | |||
| buffer_ << "</td></tr>"; | |||
| @@ -18,9 +18,9 @@ | |||
| #include <utility> | |||
| #include <fstream> | |||
| #include <sstream> | |||
| #include <climits> | |||
| #include "ir/anf.h" | |||
| #include "pipeline/parse/parse.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| #include "utils/convert_utils.h" | |||
| namespace mindspore { | |||
| std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { | |||
| @@ -24,13 +24,10 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/base.h" | |||
| #include "debug/trace_info.h" | |||
| namespace mindspore { | |||
| namespace py = pybind11; | |||
| // namespace to support intermediate representation definition | |||
| enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 }; | |||
| @@ -21,7 +21,6 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "utils/any.h" | |||
| #include "ir/anf.h" | |||
| namespace mindspore { | |||
| @@ -19,8 +19,6 @@ | |||
| #include <fstream> | |||
| #include <sstream> | |||
| #include "ir/anf.h" | |||
| #include "pipeline/parse/parse.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { | |||
| @@ -24,12 +24,9 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/base.h" | |||
| namespace mindspore { | |||
| namespace py = pybind11; | |||
| class TraceInfo; | |||
| using TraceInfoPtr = std::shared_ptr<TraceInfo>; | |||
| class Location; | |||
| @@ -23,21 +23,11 @@ | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "ir/visitor.h" | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "operator/ops.h" | |||
| #include "parallel/ops_info/ops_utils.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| // namespace to support intermediate representation definition | |||
| // Methods of AnfNode | |||
| TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); } | |||
| BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } | |||
| std::string AnfNode::ToString() const { | |||
| return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info()); | |||
| } | |||
| CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) | |||
| : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} | |||
| @@ -85,66 +75,6 @@ std::string CNode::DebugString(int recursive_level) const { | |||
| return buffer.str(); | |||
| } | |||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { | |||
| if (operator_info_ != nullptr) { | |||
| MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() | |||
| << ", using the new one: " << operator_info->name(); | |||
| auto old_ptr = operator_info_; | |||
| operator_info_ = operator_info; | |||
| return old_ptr; | |||
| } | |||
| operator_info_ = operator_info; | |||
| return nullptr; | |||
| } | |||
| std::string CNode::fullname_with_scope() { | |||
| // if full name is set, return its name immediately | |||
| if (!fullname_with_scope_.empty()) { | |||
| return fullname_with_scope_; | |||
| } | |||
| if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || | |||
| IsApply(prim::kPrimHistogramSummary)) { | |||
| std::string tag = GetValue<std::string>(GetValueNode(input(1))); | |||
| if (tag == "") { | |||
| MS_LOG(EXCEPTION) << "The tag name is null, should be valid string"; | |||
| } | |||
| std::string name; | |||
| if (IsApply(prim::kPrimScalarSummary)) { | |||
| name = tag + "[:Scalar]"; | |||
| } else if (IsApply(prim::kPrimImageSummary)) { | |||
| name = tag + "[:Image]"; | |||
| } else if (IsApply(prim::kPrimHistogramSummary)) { | |||
| name = tag + "[:Histogram]"; | |||
| } else { | |||
| name = tag + "[:Tensor]"; | |||
| } | |||
| fullname_with_scope_ = name; | |||
| } else { | |||
| // cnode input 0 should be primitive ptr | |||
| auto value_ptr = input(0)->cast<ValueNodePtr>(); | |||
| if (value_ptr == nullptr) { | |||
| MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; | |||
| fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>()); | |||
| return fullname_with_scope_; | |||
| } | |||
| auto input_value = value_ptr->value(); | |||
| if (input_value == nullptr) { | |||
| MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr."; | |||
| fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>()); | |||
| return fullname_with_scope_; | |||
| } | |||
| PrimitivePtr prim = GetValue<PrimitivePtr>(input_value); | |||
| MS_EXCEPTION_IF_NULL(scope()); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| fullname_with_scope_ = | |||
| scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>()); | |||
| } | |||
| return fullname_with_scope_; | |||
| } | |||
| std::string ValueNode::ToString() const { | |||
| MS_EXCEPTION_IF_NULL(value_); | |||
| if (value_->isa<FuncGraph>()) { | |||
| @@ -173,10 +103,6 @@ std::string ValueNode::fullname_with_scope() { | |||
| return fullname_with_scope_; | |||
| } | |||
| void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); } | |||
| void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); } | |||
| void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); } | |||
| bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -52,6 +52,7 @@ class AbstractBase; | |||
| } // namespace abstract | |||
| using BaseShapePtr = std::shared_ptr<abstract::BaseShape>; | |||
| using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>; | |||
| using AbstractBasePtrList = std::vector<AbstractBasePtr>; | |||
| class ValueNode; | |||
| using ValueNodePtr = std::shared_ptr<ValueNode>; | |||
| @@ -78,6 +79,13 @@ using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>; | |||
| class AnfVisitor; | |||
| class ParamValue { | |||
| public: | |||
| ParamValue() = default; | |||
| virtual ~ParamValue() = default; | |||
| }; | |||
| using ParamValuePtr = std::shared_ptr<ParamValue>; | |||
| // AnfNode is the basic class of the IR definition derived from Base. | |||
| // Only two types of nodes are derived: CNode and ANode. | |||
| // Methods: | |||
| @@ -239,11 +247,11 @@ class ANode : public AnfNode { | |||
| // Parameter represents the parameter inputs of a function. They have no value. | |||
| // Attributes: | |||
| // default_param_: used to hold the inputting tensor of the model. | |||
| // default_param_value_: used to hold the inputting tensor of the model. | |||
| class Parameter : public ANode { | |||
| public: | |||
| explicit Parameter(const FuncGraphPtr &func_graph) | |||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(py::none()), tensor_layout_(nullptr) {} | |||
| : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {} | |||
| ~Parameter() override = default; | |||
| MS_DECLARE_PARENT(Parameter, ANode); | |||
| @@ -254,12 +262,11 @@ class Parameter : public ANode { | |||
| std::string fullname_with_scope() override { return name(); }; | |||
| bool has_default() const { return has_default_; } | |||
| py::object default_param() { return default_param_; } | |||
| void set_default_param(const py::object &obj) { | |||
| default_param_ = obj; | |||
| void set_default_param(ParamValuePtr param) { | |||
| default_param_ = param; | |||
| has_default_ = true; | |||
| } | |||
| ParamValuePtr default_param() const { return default_param_; } | |||
| std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; } | |||
| void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) { | |||
| @@ -280,7 +287,7 @@ class Parameter : public ANode { | |||
| private: | |||
| std::string name_; | |||
| bool has_default_; | |||
| py::object default_param_; | |||
| ParamValuePtr default_param_; | |||
| std::shared_ptr<parallel::TensorLayout> tensor_layout_; | |||
| }; | |||
| using ParameterPtr = std::shared_ptr<Parameter>; | |||
| @@ -0,0 +1,103 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/anf.h" | |||
| #include <algorithm> | |||
| #include <sstream> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "ir/visitor.h" | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "operator/ops.h" | |||
| #include "parallel/ops_info/ops_utils.h" | |||
| namespace mindspore { | |||
| // namespace to support intermediate representation definition | |||
| // Methods of AnfNode | |||
| TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); } | |||
| BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } | |||
| std::string AnfNode::ToString() const { | |||
| return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info()); | |||
| } | |||
| OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { | |||
| if (operator_info_ != nullptr) { | |||
| MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() | |||
| << ", using the new one: " << operator_info->name(); | |||
| auto old_ptr = operator_info_; | |||
| operator_info_ = operator_info; | |||
| return old_ptr; | |||
| } | |||
| operator_info_ = operator_info; | |||
| return nullptr; | |||
| } | |||
| std::string CNode::fullname_with_scope() { | |||
| // if full name is set, return its name immediately | |||
| if (!fullname_with_scope_.empty()) { | |||
| return fullname_with_scope_; | |||
| } | |||
| if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || | |||
| IsApply(prim::kPrimHistogramSummary)) { | |||
| std::string tag = GetValue<std::string>(GetValueNode(input(1))); | |||
| if (tag == "") { | |||
| MS_LOG(EXCEPTION) << "The tag name is null, should be valid string"; | |||
| } | |||
| std::string name; | |||
| if (IsApply(prim::kPrimScalarSummary)) { | |||
| name = tag + "[:Scalar]"; | |||
| } else if (IsApply(prim::kPrimImageSummary)) { | |||
| name = tag + "[:Image]"; | |||
| } else if (IsApply(prim::kPrimHistogramSummary)) { | |||
| name = tag + "[:Histogram]"; | |||
| } else { | |||
| name = tag + "[:Tensor]"; | |||
| } | |||
| fullname_with_scope_ = name; | |||
| } else { | |||
| // cnode input 0 should be primitive ptr | |||
| auto value_ptr = input(0)->cast<ValueNodePtr>(); | |||
| if (value_ptr == nullptr) { | |||
| MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; | |||
| fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>()); | |||
| return fullname_with_scope_; | |||
| } | |||
| auto input_value = value_ptr->value(); | |||
| if (input_value == nullptr) { | |||
| MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr."; | |||
| fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>()); | |||
| return fullname_with_scope_; | |||
| } | |||
| PrimitivePtr prim = GetValue<PrimitivePtr>(input_value); | |||
| MS_EXCEPTION_IF_NULL(scope()); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| fullname_with_scope_ = | |||
| scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>()); | |||
| } | |||
| return fullname_with_scope_; | |||
| } | |||
| void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); } | |||
| void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); } | |||
| void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); } | |||
| } // namespace mindspore | |||
| @@ -19,9 +19,6 @@ | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| TypePtr Keyword::DeepCopy() const { | |||
| @@ -206,8 +203,6 @@ std::string Function::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| TypePtr TypeAnything::DeepCopy() const { return kAnyType; } | |||
| TypePtr JTagged::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(subtype_); | |||
| if (IsGeneric()) { | |||
| @@ -247,460 +242,4 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> proble | |||
| os << problem->ToString(); | |||
| return os; | |||
| } | |||
| std::size_t TypeHasher::operator()(TypePtr const &type) const { | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::size_t hash = std::hash<size_t>()(type->type_id()); | |||
| return hash; | |||
| } | |||
| std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { | |||
| std::size_t hash_sum = 0; | |||
| for (auto &type : type_list) { | |||
| auto type_id = static_cast<std::size_t>(type->type_id()); | |||
| hash_sum = hash_combine(hash_sum, type_id); | |||
| } | |||
| return hash_sum; | |||
| } | |||
| bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| MS_EXCEPTION_IF_NULL(t2); | |||
| return t1->type_id() == t2->type_id(); | |||
| } | |||
| bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { | |||
| if (lhs.size() != rhs.size()) { | |||
| return false; | |||
| } | |||
| std::size_t size = lhs.size(); | |||
| for (std::size_t i = 0; i < size; ++i) { | |||
| MS_EXCEPTION_IF_NULL(lhs[i]); | |||
| MS_EXCEPTION_IF_NULL(rhs[i]); | |||
| if (*lhs[i] != *rhs[i]) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| TypePtr TypeIdToType(TypeId id) { | |||
| switch (id) { | |||
| case kNumberTypeFloat16: | |||
| return kFloat16; | |||
| case kNumberTypeFloat: | |||
| case kNumberTypeFloat32: | |||
| return kFloat32; | |||
| case kNumberTypeFloat64: | |||
| return kFloat64; | |||
| case kNumberTypeInt8: | |||
| return kInt8; | |||
| case kNumberTypeInt16: | |||
| return kInt16; | |||
| case kNumberTypeInt32: | |||
| return kInt32; | |||
| case kNumberTypeInt64: | |||
| return kInt64; | |||
| case kNumberTypeUInt8: | |||
| return kUInt8; | |||
| case kNumberTypeUInt16: | |||
| return kUInt16; | |||
| case kNumberTypeUInt32: | |||
| return kUInt32; | |||
| case kNumberTypeUInt64: | |||
| return kUInt64; | |||
| case kNumberTypeBool: | |||
| return kBool; | |||
| case kMetaTypeExternal: | |||
| return kTypeExternal; | |||
| case kMetaTypeAnything: | |||
| return kAnyType; | |||
| case kMetaTypeNone: | |||
| return kTypeNone; | |||
| case kObjectTypeEnvType: | |||
| return kTypeEnv; | |||
| case kObjectTypeRefKey: | |||
| return kRefKeyType; | |||
| case kObjectTypeRef: | |||
| return kRefType; | |||
| case kTypeUnknown: | |||
| return kTypeNone; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Not support the type: " << id; | |||
| } | |||
| } | |||
| namespace { | |||
| template <typename T> | |||
| TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == num_type_name) { | |||
| type = std::make_shared<T>(); | |||
| } else { | |||
| try { | |||
| if (num_type_name.size() >= type_name.size()) { | |||
| MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name | |||
| << ")"; | |||
| } | |||
| auto bits = std::stoi(type_name.substr(num_type_name.size())); | |||
| type = std::make_shared<T>(bits); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) { | |||
| std::vector<TypePtr> types; | |||
| if (type_names.length() == 0) { | |||
| return types; | |||
| } | |||
| std::string::size_type start = 0; | |||
| std::string::size_type end = type_names.find_first_of(','); | |||
| while (end != std::string::npos) { | |||
| types.push_back(StringToType(type_names.substr(start, end))); | |||
| // Skip ',' to find the next element. | |||
| start = end + 1; | |||
| end = type_names.find_first_of(',', start); | |||
| } | |||
| if (start >= type_names.size()) { | |||
| MS_LOG(EXCEPTION) << "Type name is empty string."; | |||
| } | |||
| types.push_back(StringToType(type_names.substr(start))); | |||
| return types; | |||
| } | |||
| TypePtr TensorStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Tensor") { | |||
| type = std::make_shared<TensorType>(); | |||
| } else { | |||
| try { | |||
| auto start = type_name.find_first_of('[') + 1; | |||
| auto end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| auto element_str = type_name.substr(start, end - start); | |||
| auto element_type = StringToType(element_str); | |||
| if (element_type == nullptr) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<TensorType>(element_type); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| TypePtr ListStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "List") { | |||
| type = std::make_shared<List>(); | |||
| } else { | |||
| try { | |||
| auto start = type_name.find_first_of('[') + 1; | |||
| auto end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string element_strs = type_name.substr(start, end - start); | |||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | |||
| bool wrong = | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<List>(element_types); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| TypePtr TupleStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Tuple") { | |||
| type = std::make_shared<Tuple>(); | |||
| } else { | |||
| try { | |||
| size_t start = type_name.find_first_of('[') + 1; | |||
| size_t end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string element_strs = type_name.substr(start, end - start); | |||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | |||
| bool wrong = | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<Tuple>(element_types); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| TypePtr FunctionStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Function") { | |||
| type = std::make_shared<Function>(); | |||
| } else { | |||
| try { | |||
| // format: [(para1, para2, para3, ...) retval] | |||
| size_t start = type_name.find_first_of('[') + 1; | |||
| size_t end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string str_all = type_name.substr(start, end - start); | |||
| size_t start_a = str_all.find_first_of('(') + 1; | |||
| size_t end_a = str_all.find_last_of(')'); | |||
| if (start_a >= str_all.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string str_args = str_all.substr(start_a, end_a - start_a); | |||
| // bypass " " between ")" and retval | |||
| start = end_a + 2; | |||
| if (start >= str_all.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string str_retval = str_all.substr(start); | |||
| std::vector<TypePtr> args_type = StringToVectorOfType(str_args); | |||
| TypePtr retval = StringToType(str_retval); | |||
| bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (retval == nullptr || wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<Function>(args_type, retval); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| } // namespace | |||
| TypePtr StringToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name.compare("None") == 0) { | |||
| type = std::make_shared<TypeNone>(); | |||
| } else if (type_name.compare("Ellipsis") == 0) { | |||
| type = std::make_shared<Ellipsis>(); | |||
| } else if (type_name.compare("TypeType") == 0) { | |||
| type = std::make_shared<TypeType>(); | |||
| } else if (type_name.compare("SymbolicKeyType") == 0) { | |||
| type = std::make_shared<SymbolicKeyType>(); | |||
| } else if (type_name.compare("RefKeyType") == 0) { | |||
| type = std::make_shared<RefKeyType>(); | |||
| } else if (type_name.compare("EnvType") == 0) { | |||
| type = std::make_shared<EnvType>(); | |||
| } else if (type_name.compare("Number") == 0) { | |||
| type = std::make_shared<Number>(); | |||
| } else if (type_name.compare("Bool") == 0) { | |||
| type = std::make_shared<Bool>(); | |||
| } else if (type_name.compare(0, strlen("Int"), "Int") == 0) { | |||
| type = StringToNumberType<Int>(type_name, "Int"); | |||
| } else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) { | |||
| type = StringToNumberType<UInt>(type_name, "UInt"); | |||
| } else if (type_name.compare(0, strlen("Float"), "Float") == 0) { | |||
| type = StringToNumberType<Float>(type_name, "Float"); | |||
| } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) { | |||
| type = TensorStrToType(type_name); | |||
| } else if (type_name.compare(0, strlen("List"), "List") == 0) { | |||
| type = ListStrToType(type_name); | |||
| } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { | |||
| type = TupleStrToType(type_name); | |||
| } else if (type_name.compare("Slice") == 0) { | |||
| type = std::make_shared<Slice>(); | |||
| } else if (type_name.compare("Dictionary") == 0) { | |||
| type = std::make_shared<Dictionary>(); | |||
| } else if (type_name.compare("String") == 0) { | |||
| type = std::make_shared<String>(); | |||
| } else if (type_name.compare("Problem") == 0) { | |||
| type = std::make_shared<Problem>(); | |||
| } else if (type_name.compare(0, strlen("Function"), "Function") == 0) { | |||
| type = FunctionStrToType(type_name); | |||
| } else { | |||
| // - unsupported to convert | |||
| // Class | |||
| // SymbolicType | |||
| // JTagged | |||
| // Anything | |||
| // External | |||
| // Problem | |||
| MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!"; | |||
| } | |||
| return type; | |||
| } | |||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | |||
| if (x == nullptr || base_type == nullptr) { | |||
| MS_LOG(ERROR) << "Type is nullptr."; | |||
| return false; | |||
| } | |||
| if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { | |||
| return false; | |||
| } else if (!(base_type->IsGeneric())) { | |||
| return *(base_type) == *(x); | |||
| } else if (base_type->type_id() == x->type_id()) { | |||
| return true; | |||
| } else if (base_type->type_id() == x->generic_type_id()) { | |||
| return true; | |||
| } else if (base_type->type_id() == x->object_type()) { | |||
| return true; | |||
| } else if (base_type->type_id() == x->meta_type()) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool IsSubType(TypePtr const &t1, TypePtr const &t2) { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| if (t1->type_id() == kTypeUnknown) { | |||
| return false; | |||
| } else if (t2 != nullptr) { | |||
| return IsIdentidityOrSubclass(t1, t2); | |||
| } else { | |||
| return true; | |||
| } | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| typing, ([](py::module *const m) { | |||
| auto m_sub = m->def_submodule("typing", "submodule for dtype"); | |||
| py::enum_<TypeId>(m_sub, "TypeId"); | |||
| (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); | |||
| (void)m_sub.def("load_type", &TypeIdToType, "load type"); | |||
| (void)m_sub.def( | |||
| "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); | |||
| (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | |||
| .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | |||
| .def("__eq__", | |||
| [](const TypePtr &t1, const TypePtr &t2) { | |||
| if (t1 != nullptr && t2 != nullptr) { | |||
| return *t1 == *t2; | |||
| } | |||
| return false; | |||
| }) | |||
| .def("__hash__", &Type::hash) | |||
| .def("__str__", &Type::ToString) | |||
| .def("__repr__", &Type::ReprString) | |||
| .def("__deepcopy__", [](const TypePtr &t, py::dict) { | |||
| if (t == nullptr) { | |||
| return static_cast<TypePtr>(nullptr); | |||
| } | |||
| return t->DeepCopy(); | |||
| }); | |||
| (void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init()); | |||
| (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") | |||
| .def(py::init()) | |||
| .def(py::pickle( | |||
| [](const Bool &) { // __getstate__ | |||
| return py::make_tuple(); | |||
| }, | |||
| [](const py::tuple &) { // __setstate__ | |||
| return std::make_shared<Bool>(); | |||
| })); | |||
| (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const Int &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| Int data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const UInt &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| UInt data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const Float &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| Float data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List") | |||
| .def(py::init()) | |||
| .def(py::init<std::vector<TypePtr>>(), py::arg("elements")); | |||
| (void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple") | |||
| .def(py::init()) | |||
| .def(py::init<std::vector<TypePtr>>(), py::arg("elements")); | |||
| (void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType") | |||
| .def(py::init()) | |||
| .def(py::init<TypePtr>(), py::arg("element")) | |||
| .def("element_type", &TensorType::element) | |||
| .def(py::pickle( | |||
| [](const TensorType &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id()))); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>())))); | |||
| return data; | |||
| })); | |||
| (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function") | |||
| .def(py::init()) | |||
| .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval")); | |||
| (void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init()); | |||
| (void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init()); | |||
| (void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init()); | |||
| (void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init()); | |||
| (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init()); | |||
| (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init()); | |||
| (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | |||
| (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||
| (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | |||
| })); | |||
| const TypePtr kTypeExternal = std::make_shared<External>(); | |||
| const TypePtr kTypeEnv = std::make_shared<EnvType>(); | |||
| const TypePtr kTypeType = std::make_shared<TypeType>(); | |||
| const TypePtr kTensorType = std::make_shared<TensorType>(); | |||
| const TypePtr kString = std::make_shared<String>(); | |||
| } // namespace mindspore | |||
| @@ -19,9 +19,6 @@ | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) { | |||
| @@ -19,9 +19,6 @@ | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| bool Number::operator==(const Type &other) const { | |||
| @@ -19,9 +19,6 @@ | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| TypePtr RefType::DeepCopy() const { | |||
| @@ -21,9 +21,8 @@ | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| #include "ir/dtype/number.h" | |||
| #include "utils/convert_utils.h" | |||
| namespace mindspore { | |||
| TypeId IntBitsToTypeId(const int nbits) { | |||
| @@ -227,11 +226,6 @@ bool Type::operator==(const Value &other) const { | |||
| } | |||
| } | |||
| abstract::AbstractBasePtr Type::ToAbstract() { | |||
| abstract::AbstractBasePtr ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>()); | |||
| return ptr; | |||
| } | |||
| std::ostream &operator<<(std::ostream &os, const Type &type) { | |||
| os << type.ToString(); | |||
| return os; | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/dtype/type.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace mindspore { | |||
| abstract::AbstractBasePtr Type::ToAbstract() { | |||
| auto ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>()); | |||
| return ptr; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,484 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/dtype.h" | |||
| #include <string> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "utils/log_adapter.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| TypePtr TypeAnything::DeepCopy() const { return kAnyType; } | |||
| std::size_t TypeHasher::operator()(TypePtr const &type) const { | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::size_t hash = std::hash<size_t>()(type->type_id()); | |||
| return hash; | |||
| } | |||
| std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { | |||
| std::size_t hash_sum = 0; | |||
| for (auto &type : type_list) { | |||
| auto type_id = static_cast<std::size_t>(type->type_id()); | |||
| hash_sum = hash_combine(hash_sum, type_id); | |||
| } | |||
| return hash_sum; | |||
| } | |||
| bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| MS_EXCEPTION_IF_NULL(t2); | |||
| return t1->type_id() == t2->type_id(); | |||
| } | |||
| bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { | |||
| if (lhs.size() != rhs.size()) { | |||
| return false; | |||
| } | |||
| std::size_t size = lhs.size(); | |||
| for (std::size_t i = 0; i < size; ++i) { | |||
| MS_EXCEPTION_IF_NULL(lhs[i]); | |||
| MS_EXCEPTION_IF_NULL(rhs[i]); | |||
| if (*lhs[i] != *rhs[i]) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| TypePtr TypeIdToType(TypeId id) { | |||
| switch (id) { | |||
| case kNumberTypeFloat16: | |||
| return kFloat16; | |||
| case kNumberTypeFloat: | |||
| case kNumberTypeFloat32: | |||
| return kFloat32; | |||
| case kNumberTypeFloat64: | |||
| return kFloat64; | |||
| case kNumberTypeInt8: | |||
| return kInt8; | |||
| case kNumberTypeInt16: | |||
| return kInt16; | |||
| case kNumberTypeInt32: | |||
| return kInt32; | |||
| case kNumberTypeInt64: | |||
| return kInt64; | |||
| case kNumberTypeUInt8: | |||
| return kUInt8; | |||
| case kNumberTypeUInt16: | |||
| return kUInt16; | |||
| case kNumberTypeUInt32: | |||
| return kUInt32; | |||
| case kNumberTypeUInt64: | |||
| return kUInt64; | |||
| case kNumberTypeBool: | |||
| return kBool; | |||
| case kMetaTypeExternal: | |||
| return kTypeExternal; | |||
| case kMetaTypeAnything: | |||
| return kAnyType; | |||
| case kMetaTypeNone: | |||
| return kTypeNone; | |||
| case kObjectTypeEnvType: | |||
| return kTypeEnv; | |||
| case kObjectTypeRefKey: | |||
| return kRefKeyType; | |||
| case kObjectTypeRef: | |||
| return kRefType; | |||
| case kTypeUnknown: | |||
| return kTypeNone; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Not support the type: " << id; | |||
| } | |||
| } | |||
| namespace { | |||
| template <typename T> | |||
| TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == num_type_name) { | |||
| type = std::make_shared<T>(); | |||
| } else { | |||
| try { | |||
| if (num_type_name.size() >= type_name.size()) { | |||
| MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name | |||
| << ")"; | |||
| } | |||
| auto bits = std::stoi(type_name.substr(num_type_name.size())); | |||
| type = std::make_shared<T>(bits); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) { | |||
| std::vector<TypePtr> types; | |||
| if (type_names.length() == 0) { | |||
| return types; | |||
| } | |||
| std::string::size_type start = 0; | |||
| std::string::size_type end = type_names.find_first_of(','); | |||
| while (end != std::string::npos) { | |||
| types.push_back(StringToType(type_names.substr(start, end))); | |||
| // Skip ',' to find the next element. | |||
| start = end + 1; | |||
| end = type_names.find_first_of(',', start); | |||
| } | |||
| if (start >= type_names.size()) { | |||
| MS_LOG(EXCEPTION) << "Type name is empty string."; | |||
| } | |||
| types.push_back(StringToType(type_names.substr(start))); | |||
| return types; | |||
| } | |||
| TypePtr TensorStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Tensor") { | |||
| type = std::make_shared<TensorType>(); | |||
| } else { | |||
| try { | |||
| auto start = type_name.find_first_of('[') + 1; | |||
| auto end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| auto element_str = type_name.substr(start, end - start); | |||
| auto element_type = StringToType(element_str); | |||
| if (element_type == nullptr) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<TensorType>(element_type); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| TypePtr ListStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "List") { | |||
| type = std::make_shared<List>(); | |||
| } else { | |||
| try { | |||
| auto start = type_name.find_first_of('[') + 1; | |||
| auto end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string element_strs = type_name.substr(start, end - start); | |||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | |||
| bool wrong = | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<List>(element_types); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| TypePtr TupleStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Tuple") { | |||
| type = std::make_shared<Tuple>(); | |||
| } else { | |||
| try { | |||
| size_t start = type_name.find_first_of('[') + 1; | |||
| size_t end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string element_strs = type_name.substr(start, end - start); | |||
| std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); | |||
| bool wrong = | |||
| std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<Tuple>(element_types); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| TypePtr FunctionStrToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name == "Function") { | |||
| type = std::make_shared<Function>(); | |||
| } else { | |||
| try { | |||
| // format: [(para1, para2, para3, ...) retval] | |||
| size_t start = type_name.find_first_of('[') + 1; | |||
| size_t end = type_name.find_last_of(']'); | |||
| if (start >= type_name.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string str_all = type_name.substr(start, end - start); | |||
| size_t start_a = str_all.find_first_of('(') + 1; | |||
| size_t end_a = str_all.find_last_of(')'); | |||
| if (start_a >= str_all.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string str_args = str_all.substr(start_a, end_a - start_a); | |||
| // bypass " " between ")" and retval | |||
| start = end_a + 2; | |||
| if (start >= str_all.size()) { | |||
| return nullptr; | |||
| } | |||
| std::string str_retval = str_all.substr(start); | |||
| std::vector<TypePtr> args_type = StringToVectorOfType(str_args); | |||
| TypePtr retval = StringToType(str_retval); | |||
| bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); | |||
| if (retval == nullptr || wrong) { | |||
| return nullptr; | |||
| } | |||
| type = std::make_shared<Function>(args_type, retval); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); | |||
| } | |||
| } | |||
| return type; | |||
| } | |||
| } // namespace | |||
| TypePtr StringToType(const std::string &type_name) { | |||
| TypePtr type = nullptr; | |||
| if (type_name.compare("None") == 0) { | |||
| type = std::make_shared<TypeNone>(); | |||
| } else if (type_name.compare("Ellipsis") == 0) { | |||
| type = std::make_shared<Ellipsis>(); | |||
| } else if (type_name.compare("TypeType") == 0) { | |||
| type = std::make_shared<TypeType>(); | |||
| } else if (type_name.compare("SymbolicKeyType") == 0) { | |||
| type = std::make_shared<SymbolicKeyType>(); | |||
| } else if (type_name.compare("RefKeyType") == 0) { | |||
| type = std::make_shared<RefKeyType>(); | |||
| } else if (type_name.compare("EnvType") == 0) { | |||
| type = std::make_shared<EnvType>(); | |||
| } else if (type_name.compare("Number") == 0) { | |||
| type = std::make_shared<Number>(); | |||
| } else if (type_name.compare("Bool") == 0) { | |||
| type = std::make_shared<Bool>(); | |||
| } else if (type_name.compare(0, strlen("Int"), "Int") == 0) { | |||
| type = StringToNumberType<Int>(type_name, "Int"); | |||
| } else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) { | |||
| type = StringToNumberType<UInt>(type_name, "UInt"); | |||
| } else if (type_name.compare(0, strlen("Float"), "Float") == 0) { | |||
| type = StringToNumberType<Float>(type_name, "Float"); | |||
| } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) { | |||
| type = TensorStrToType(type_name); | |||
| } else if (type_name.compare(0, strlen("List"), "List") == 0) { | |||
| type = ListStrToType(type_name); | |||
| } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { | |||
| type = TupleStrToType(type_name); | |||
| } else if (type_name.compare("Slice") == 0) { | |||
| type = std::make_shared<Slice>(); | |||
| } else if (type_name.compare("Dictionary") == 0) { | |||
| type = std::make_shared<Dictionary>(); | |||
| } else if (type_name.compare("String") == 0) { | |||
| type = std::make_shared<String>(); | |||
| } else if (type_name.compare("Problem") == 0) { | |||
| type = std::make_shared<Problem>(); | |||
| } else if (type_name.compare(0, strlen("Function"), "Function") == 0) { | |||
| type = FunctionStrToType(type_name); | |||
| } else { | |||
| // - unsupported to convert | |||
| // Class | |||
| // SymbolicType | |||
| // JTagged | |||
| // Anything | |||
| // External | |||
| // Problem | |||
| MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!"; | |||
| } | |||
| return type; | |||
| } | |||
| bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { | |||
| if (x == nullptr || base_type == nullptr) { | |||
| MS_LOG(ERROR) << "Type is nullptr."; | |||
| return false; | |||
| } | |||
| if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { | |||
| return false; | |||
| } else if (!(base_type->IsGeneric())) { | |||
| return *(base_type) == *(x); | |||
| } else if (base_type->type_id() == x->type_id()) { | |||
| return true; | |||
| } else if (base_type->type_id() == x->generic_type_id()) { | |||
| return true; | |||
| } else if (base_type->type_id() == x->object_type()) { | |||
| return true; | |||
| } else if (base_type->type_id() == x->meta_type()) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool IsSubType(TypePtr const &t1, TypePtr const &t2) { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| if (t1->type_id() == kTypeUnknown) { | |||
| return false; | |||
| } else if (t2 != nullptr) { | |||
| return IsIdentidityOrSubclass(t1, t2); | |||
| } else { | |||
| return true; | |||
| } | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| typing, ([](py::module *const m) { | |||
| auto m_sub = m->def_submodule("typing", "submodule for dtype"); | |||
| py::enum_<TypeId>(m_sub, "TypeId"); | |||
| (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); | |||
| (void)m_sub.def("load_type", &TypeIdToType, "load type"); | |||
| (void)m_sub.def( | |||
| "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); | |||
| (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | |||
| .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | |||
| .def("__eq__", | |||
| [](const TypePtr &t1, const TypePtr &t2) { | |||
| if (t1 != nullptr && t2 != nullptr) { | |||
| return *t1 == *t2; | |||
| } | |||
| return false; | |||
| }) | |||
| .def("__hash__", &Type::hash) | |||
| .def("__str__", &Type::ToString) | |||
| .def("__repr__", &Type::ReprString) | |||
| .def("__deepcopy__", [](const TypePtr &t, py::dict) { | |||
| if (t == nullptr) { | |||
| return static_cast<TypePtr>(nullptr); | |||
| } | |||
| return t->DeepCopy(); | |||
| }); | |||
| (void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init()); | |||
| (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") | |||
| .def(py::init()) | |||
| .def(py::pickle( | |||
| [](const Bool &) { // __getstate__ | |||
| return py::make_tuple(); | |||
| }, | |||
| [](const py::tuple &) { // __setstate__ | |||
| return std::make_shared<Bool>(); | |||
| })); | |||
| (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const Int &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| Int data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const UInt &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| UInt data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const Float &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| Float data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List") | |||
| .def(py::init()) | |||
| .def(py::init<std::vector<TypePtr>>(), py::arg("elements")); | |||
| (void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple") | |||
| .def(py::init()) | |||
| .def(py::init<std::vector<TypePtr>>(), py::arg("elements")); | |||
| (void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType") | |||
| .def(py::init()) | |||
| .def(py::init<TypePtr>(), py::arg("element")) | |||
| .def("element_type", &TensorType::element) | |||
| .def(py::pickle( | |||
| [](const TensorType &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id()))); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>())))); | |||
| return data; | |||
| })); | |||
| (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function") | |||
| .def(py::init()) | |||
| .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval")); | |||
| (void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init()); | |||
| (void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init()); | |||
| (void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init()); | |||
| (void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init()); | |||
| (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init()); | |||
| (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init()); | |||
| (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init()); | |||
| (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init()); | |||
| (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init()); | |||
| })); | |||
| const TypePtr kTypeExternal = std::make_shared<External>(); | |||
| const TypePtr kTypeEnv = std::make_shared<EnvType>(); | |||
| const TypePtr kTypeType = std::make_shared<TypeType>(); | |||
| const TypePtr kTensorType = std::make_shared<TensorType>(); | |||
| const TypePtr kString = std::make_shared<String>(); | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <algorithm> | |||
| #include "ir/manager.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/profile.h" | |||
| @@ -69,7 +70,9 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, | |||
| new_param->set_abstract(old_param->abstract()); | |||
| new_param->set_name(old_param->name()); | |||
| if (old_param->has_default()) { | |||
| new_param->set_default_param(old_param->default_param()); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param()); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(param_value->value()); | |||
| new_param->set_default_param(param_value_new); | |||
| } | |||
| ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); | |||
| new_param->set_scope(scope); | |||
| @@ -248,7 +251,9 @@ void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { | |||
| if (node->isa<Parameter>()) { | |||
| ParameterPtr old_param = dyn_cast<Parameter>(node); | |||
| if (old_param->has_default()) { | |||
| param->set_default_param(old_param->default_param()); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param()); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(param_value->value()); | |||
| param->set_default_param(param_value_new); | |||
| } | |||
| param->set_name(old_param->name()); | |||
| } | |||
| @@ -28,6 +28,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/manager.h" | |||
| namespace mindspore { | |||
| class Cloner; | |||
| @@ -31,6 +31,7 @@ | |||
| #include "ir/dtype.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/signature.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace py = pybind11; | |||
| @@ -21,7 +21,6 @@ | |||
| #include <memory> | |||
| #include <functional> | |||
| #include "ir/base.h" | |||
| #include "ir/anf.h" | |||
| namespace mindspore { | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_ | |||
| #define MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_ | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| namespace mindspore { | |||
| class ParamValueMinnie : public ParamValue { | |||
| public: | |||
| ParamValueMinnie() : tensor_addr_(nullptr), tensor_size_(0) {} | |||
| virtual ~ParamValueMinnie() = default; | |||
| size_t tensor_size() const { return tensor_size_; } | |||
| void set_tensor_size(size_t size) { tensor_size_ = size; } | |||
| void *tensor_addr() const { return tensor_addr_; } | |||
| void set_tensor_addr(void *addr) { tensor_addr_ = addr; } | |||
| private: | |||
| void *tensor_addr_; | |||
| size_t tensor_size_; | |||
| }; | |||
| using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_ | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_ | |||
| #define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_ | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "pybind11/pybind11.h" | |||
| namespace mindspore { | |||
| namespace py = pybind11; | |||
| class ParamValuePy : public ParamValue { | |||
| public: | |||
| ParamValuePy() : value_(py::none()) {} | |||
| explicit ParamValuePy(py::object value) : value_(value) {} | |||
| virtual ~ParamValuePy() = default; | |||
| py::object value() { return value_; } | |||
| void set_value(const py::object &obj) { value_ = obj; } | |||
| private: | |||
| py::object value_; | |||
| }; | |||
| using ParamValuePyPtr = std::shared_ptr<ParamValuePy>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_ | |||
| @@ -17,20 +17,23 @@ | |||
| #ifndef MINDSPORE_CCSRC_IR_SCALAR_H_ | |||
| #define MINDSPORE_CCSRC_IR_SCALAR_H_ | |||
| namespace mindspore { | |||
| /* namespace to support inference engine */ | |||
| #include <type_traits> | |||
| #include <algorithm> | |||
| #include <cmath> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <sstream> | |||
| #include <utility> | |||
| #include <cfloat> | |||
| #include "ir/base.h" | |||
| #include "ir/dtype.h" | |||
| #include "ir/dtype/number.h" | |||
| using std::fabs; | |||
| namespace mindspore { | |||
| class Scalar : public Value { | |||
| public: | |||
| Scalar() = default; | |||
| @@ -19,9 +19,7 @@ | |||
| #include <memory> | |||
| #include <cmath> | |||
| #include <cfloat> | |||
| #include "pybind_api/api_register.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "utils/convert_utils.h" | |||
| namespace mindspore { | |||
| const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const { | |||
| @@ -208,41 +206,6 @@ bool AnyValue::operator==(const Value &other) const { | |||
| } | |||
| } | |||
| const ValuePtr kAnyValue = std::make_shared<AnyValue>(); | |||
| using ContextPtr = abstract::AnalysisContextPtr; | |||
| abstract::AbstractBasePtr Scalar::ToAbstract() { | |||
| return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>()); | |||
| } | |||
| abstract::AbstractBasePtr StringImm::ToAbstract() { | |||
| return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>()); | |||
| } | |||
| abstract::AbstractBasePtr RefKey::ToAbstract() { | |||
| auto refkey = std::make_shared<abstract::AbstractRefKey>(); | |||
| refkey->set_value(shared_from_base<Value>()); | |||
| return refkey; | |||
| } | |||
| abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); } | |||
| abstract::AbstractBasePtr ValueTuple::ToAbstract() { | |||
| abstract::AbstractBasePtrList a_list; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| return ele->ToAbstract(); | |||
| }); | |||
| return std::make_shared<abstract::AbstractTuple>(a_list); | |||
| } | |||
| abstract::AbstractBasePtr ValueList::ToAbstract() { | |||
| abstract::AbstractBasePtrList a_list; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| return ele->ToAbstract(); | |||
| }); | |||
| return std::make_shared<abstract::AbstractList>(a_list); | |||
| } | |||
| std::size_t ValueSlice::hash() const { | |||
| MS_EXCEPTION_IF_NULL(start_); | |||
| @@ -280,16 +243,6 @@ std::string ValueSlice::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| abstract::AbstractBasePtr ValueSlice::ToAbstract() { | |||
| MS_EXCEPTION_IF_NULL(start_); | |||
| MS_EXCEPTION_IF_NULL(stop_); | |||
| MS_EXCEPTION_IF_NULL(step_); | |||
| abstract::AbstractBasePtr start = start_->ToAbstract(); | |||
| abstract::AbstractBasePtr end = stop_->ToAbstract(); | |||
| abstract::AbstractBasePtr step = step_->ToAbstract(); | |||
| return std::make_shared<abstract::AbstractSlice>(start, end, step); | |||
| } | |||
| std::size_t KeywordArg::hash() const { | |||
| MS_EXCEPTION_IF_NULL(value_); | |||
| return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()}); | |||
| @@ -316,12 +269,6 @@ std::string KeywordArg::ToString() const { | |||
| return buffer.str(); | |||
| } | |||
| abstract::AbstractBasePtr KeywordArg::ToAbstract() { | |||
| MS_EXCEPTION_IF_NULL(value_); | |||
| abstract::AbstractBasePtr argument = value_->ToAbstract(); | |||
| return std::make_shared<abstract::AbstractKeywordArg>(key_, argument); | |||
| } | |||
| const ValuePtr ValueDictionary::operator[](const std::string &key) const { | |||
| auto it = std::find_if(key_values_.begin(), key_values_.end(), | |||
| [key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; }); | |||
| @@ -354,17 +301,4 @@ bool ValueDictionary::operator==(const ValueDictionary &other) const { | |||
| } | |||
| return true; | |||
| } | |||
| abstract::AbstractBasePtr ValueDictionary::ToAbstract() { | |||
| std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv; | |||
| (void)std::transform( | |||
| key_values_.begin(), key_values_.end(), std::back_inserter(kv), | |||
| [](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); | |||
| return std::make_shared<abstract::AbstractDictionary>(kv); | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| RefKey, ([](const py::module *m) { | |||
| (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag")); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -29,6 +29,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/dtype.h" | |||
| #include "ir/scalar.h" | |||
| #include "ir/dtype/ref.h" | |||
| #include "utils/hashing.h" | |||
| #include "common/utils.h" | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/value.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <cmath> | |||
| #include <cfloat> | |||
| #include "pybind_api/api_register.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace mindspore { | |||
| using ContextPtr = abstract::AnalysisContextPtr; | |||
| abstract::AbstractBasePtr Scalar::ToAbstract() { | |||
| return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>()); | |||
| } | |||
| abstract::AbstractBasePtr StringImm::ToAbstract() { | |||
| return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>()); | |||
| } | |||
| abstract::AbstractBasePtr RefKey::ToAbstract() { | |||
| auto refkey = std::make_shared<abstract::AbstractRefKey>(); | |||
| refkey->set_value(shared_from_base<Value>()); | |||
| return refkey; | |||
| } | |||
| abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); } | |||
| abstract::AbstractBasePtr ValueTuple::ToAbstract() { | |||
| abstract::AbstractBasePtrList a_list; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| return ele->ToAbstract(); | |||
| }); | |||
| return std::make_shared<abstract::AbstractTuple>(a_list); | |||
| } | |||
| abstract::AbstractBasePtr ValueList::ToAbstract() { | |||
| abstract::AbstractBasePtrList a_list; | |||
| (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { | |||
| MS_EXCEPTION_IF_NULL(ele); | |||
| return ele->ToAbstract(); | |||
| }); | |||
| return std::make_shared<abstract::AbstractList>(a_list); | |||
| } | |||
| abstract::AbstractBasePtr ValueSlice::ToAbstract() { | |||
| MS_EXCEPTION_IF_NULL(start_); | |||
| MS_EXCEPTION_IF_NULL(stop_); | |||
| MS_EXCEPTION_IF_NULL(step_); | |||
| abstract::AbstractBasePtr start = start_->ToAbstract(); | |||
| abstract::AbstractBasePtr end = stop_->ToAbstract(); | |||
| abstract::AbstractBasePtr step = step_->ToAbstract(); | |||
| return std::make_shared<abstract::AbstractSlice>(start, end, step); | |||
| } | |||
| abstract::AbstractBasePtr KeywordArg::ToAbstract() { | |||
| MS_EXCEPTION_IF_NULL(value_); | |||
| abstract::AbstractBasePtr argument = value_->ToAbstract(); | |||
| return std::make_shared<abstract::AbstractKeywordArg>(key_, argument); | |||
| } | |||
| abstract::AbstractBasePtr ValueDictionary::ToAbstract() { | |||
| std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv; | |||
| (void)std::transform( | |||
| key_values_.begin(), key_values_.end(), std::back_inserter(kv), | |||
| [](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); | |||
| return std::make_shared<abstract::AbstractDictionary>(kv); | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| RefKey, ([](const py::module *m) { | |||
| (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag")); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -26,6 +26,7 @@ | |||
| #include "debug/anf_ir_utils.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/param_value_py.h" | |||
| namespace mindspore { | |||
| enum OpMergeMode { | |||
| @@ -424,7 +425,8 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP | |||
| initializer_proto->set_name(param_ptr->ToString()); | |||
| SetTensorProtoInfo(param_ptr, initializer_proto); | |||
| // set value for initializer | |||
| py::object obj = param_ptr->default_param(); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param()); | |||
| py::object obj = param_value->value(); | |||
| py::object data = obj.attr("data"); | |||
| if (py::isinstance<tensor::Tensor>(data)) { | |||
| auto method = data.attr("asnumpy"); | |||
| @@ -18,9 +18,10 @@ | |||
| #define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_ | |||
| #include "pybind11/stl.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/anf.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| py::dict GetParameterLayout(const FuncGraphPtr &graph); | |||
| @@ -19,6 +19,7 @@ | |||
| #include <string> | |||
| #include "ir/anf.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| @@ -37,7 +38,8 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { | |||
| if (!para_ptr->has_default()) { | |||
| return false; | |||
| } | |||
| return py::cast<bool>(parse::python_adapter::GetPyObjAttr(para_ptr->default_param(), "requires_grad")); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(para_ptr->default_param()); | |||
| return py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad")); | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -28,6 +28,7 @@ | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "ir/meta_tensor.h" | |||
| #include "optimizer/opt.h" | |||
| #include "optimizer/optimizer.h" | |||
| @@ -190,8 +191,8 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) { | |||
| if (input->isa<Parameter>()) { | |||
| auto input_parameter = input->cast<ParameterPtr>(); | |||
| if (input_parameter->has_default()) { | |||
| bool require_grad = | |||
| py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad")); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param()); | |||
| bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad")); | |||
| is_parameter.push_back(require_grad); | |||
| } else { | |||
| is_parameter.push_back(false); | |||
| @@ -835,8 +836,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| auto casted_target_parameter = target_parameter->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(casted_target_parameter); | |||
| if (casted_target_parameter->has_default()) { | |||
| bool require_grad = py::cast<bool>( | |||
| parse::python_adapter::GetPyObjAttr(casted_target_parameter->default_param(), "requires_grad")); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(casted_target_parameter->default_param()); | |||
| bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad")); | |||
| is_parameter.push_back(require_grad); | |||
| } else { | |||
| is_parameter.push_back(false); | |||
| @@ -28,6 +28,7 @@ | |||
| #include <utility> | |||
| #include "ir/meta_tensor.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "parallel/auto_parallel/graph_costmodel.h" | |||
| @@ -1292,7 +1293,8 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_nod | |||
| return false; | |||
| } | |||
| py::object clone_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param()); | |||
| py::object clone_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO); | |||
| bool cloned = py::cast<bool>(parse::python_adapter::GetPyObjAttr(clone_info, CLONED)); | |||
| if (!cloned) { | |||
| return false; | |||
| @@ -1314,7 +1316,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| } | |||
| // get the cloned index | |||
| py::object cloned_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param()); | |||
| py::object cloned_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO); | |||
| int32_t cloned_index = py::cast<int32_t>(parse::python_adapter::GetPyObjAttr(cloned_info, CLONED_INDEX)); | |||
| // find the be cloned parameter | |||
| @@ -1329,7 +1332,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| continue; | |||
| } | |||
| py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(be_cloned_parameter->default_param(), CLONE_INFO); | |||
| auto param_value_cloned = std::dynamic_pointer_cast<ParamValuePy>(be_cloned_parameter->default_param()); | |||
| py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(param_value_cloned->value(), CLONE_INFO); | |||
| if (!py::cast<bool>(parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED))) { | |||
| continue; | |||
| } | |||
| @@ -2072,9 +2076,9 @@ std::string NodeParameterName(const CNodePtr &node) { | |||
| if (input->isa<Parameter>()) { | |||
| auto input_parameter = input->cast<ParameterPtr>(); | |||
| if (input_parameter->has_default()) { | |||
| if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) { | |||
| return py::cast<std::string>( | |||
| parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME)); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param()); | |||
| if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), REQUIRES_GRAD))) { | |||
| return py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), PARAM_NAME)); | |||
| } | |||
| } | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include <functional> | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "parallel/costmodel_context.h" | |||
| #include "parallel/context.h" | |||
| #include "pipeline/pass.h" | |||
| @@ -225,8 +226,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| for (const auto ¶m : func_graph->parameters()) { | |||
| auto param_node = std::static_pointer_cast<Parameter>(param); | |||
| if (param_node->has_default()) { | |||
| AbstractBasePtr ptr = | |||
| abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param()); | |||
| AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true); | |||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); | |||
| args_spec.push_back(ptr); | |||
| @@ -25,8 +25,11 @@ | |||
| #include "operator/ops.h" | |||
| #include "debug/info.h" | |||
| #include "debug/trace.h" | |||
| #include "pybind11/pybind11.h" | |||
| namespace mindspore { | |||
| namespace py = pybind11; | |||
| namespace parse { | |||
| FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) { | |||
| func_graph_ = std::make_shared<FuncGraph>(); | |||
| @@ -18,11 +18,13 @@ | |||
| #define PIPELINE_PARSE_PARSE_BASE_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/manager.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace parse { | |||
| // define the node type | |||
| @@ -21,6 +21,7 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "ir/param_value_py.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "pipeline/parse/parse.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| @@ -101,8 +102,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object | |||
| } | |||
| } | |||
| if (para_node == nullptr) { | |||
| ParameterPtr node = top_graph->AddWeightParameter(param_name); | |||
| node->set_default_param(obj); | |||
| auto node = top_graph->AddWeightParameter(param_name); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(obj); | |||
| node->set_default_param(param_value_new); | |||
| // set_abstract for parameter | |||
| auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input")); | |||
| @@ -24,6 +24,7 @@ | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include "ir/param_value_py.h" | |||
| #include "pipeline/pass.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "optimizer/ad/dfunctor.h" | |||
| @@ -619,7 +620,12 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V | |||
| // maybe some default parameter | |||
| for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) { | |||
| MS_EXCEPTION_IF_NULL(graph_params[i]); | |||
| py::object obj = dyn_cast<Parameter>(graph_params[i])->default_param(); | |||
| auto param_ptr = (graph_params[i])->cast<ParameterPtr>(); | |||
| if (!param_ptr->has_default()) { | |||
| MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; | |||
| } | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param()); | |||
| py::object obj = param_value->value(); | |||
| py::object p_value = py::cast<py::object>(parse::python_adapter::GetPyObjAttr(obj, "default_input")); | |||
| (*arg_list).push_back(p_value); | |||
| } | |||
| @@ -20,6 +20,7 @@ | |||
| #include <unordered_set> | |||
| #include "common/utils.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "device/kernel_info.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| @@ -232,7 +233,9 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { | |||
| new_parameter->set_abstract(parameter->abstract()); | |||
| new_parameter->set_name(parameter->name()); | |||
| if (AnfAlgo::IsParameterWeight(parameter)) { | |||
| new_parameter->set_default_param(parameter->default_param()); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param()); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(param_value->value()); | |||
| new_parameter->set_default_param(param_value_new); | |||
| kernel_info->SetFeatureMapFlag(false); | |||
| } else { | |||
| kernel_info->SetFeatureMapFlag(true); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <unordered_set> | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "ir/manager.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "operator/ops.h" | |||
| #include "common/trans.h" | |||
| #include "utils/context/ms_context.h" | |||
| @@ -44,10 +45,11 @@ tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) { | |||
| return nullptr; | |||
| } | |||
| auto parameter = node->cast<ParameterPtr>(); | |||
| if (parameter == nullptr) { | |||
| if (parameter == nullptr || !parameter->has_default()) { | |||
| return nullptr; | |||
| } | |||
| auto py_param = parameter->default_param(); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param()); | |||
| auto py_param = param_value->value(); | |||
| if (!py::hasattr(py_param, "default_input")) { | |||
| return nullptr; | |||
| } | |||
| @@ -319,7 +321,8 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| if (tensor_mask == kParameterWeightTensorMask) { | |||
| py::object obj; | |||
| param->set_default_param(obj); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(obj); | |||
| param->set_default_param(param_value_new); | |||
| } | |||
| // set the kernel info of parameter | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| @@ -25,9 +25,11 @@ | |||
| #include <sstream> | |||
| #include <utility> | |||
| #include <iterator> | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/value.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| class BaseRef; | |||
| class VectorRef; | |||
| @@ -16,6 +16,7 @@ | |||
| #include "utils/callbacks_ge.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "transform/df_graph_manager.h" | |||
| #include "transform/util.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| @@ -49,7 +50,11 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name, | |||
| return false; | |||
| } | |||
| if (param_node->name() == param_name) { | |||
| py::object parameter = param_node->default_param(); | |||
| py::object parameter; | |||
| if (param_node->has_default()) { | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param()); | |||
| parameter = param_value->value(); | |||
| } | |||
| ValuePtr value = parse::data_converter::PyDataToValue(parameter); | |||
| TensorPtr tensor = std::dynamic_pointer_cast<tensor::Tensor>(value); | |||
| if (tensor == nullptr) { | |||
| @@ -19,7 +19,9 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/manager.h" | |||
| #include "ir/func_graph.h" | |||
| #include "pipeline/parse/parse_base.h" | |||
| #include "pipeline/parse/parse.h" | |||
| #include "./common.h" | |||
| @@ -16,6 +16,10 @@ | |||
| #include <iostream> | |||
| #include "common/common_test.h" | |||
| #include "ir/dtype.h" | |||
| #include "ir/dtype/ref.h" | |||
| #include "ir/dtype/number.h" | |||
| #include "ir/dtype/container.h" | |||
| #include "ir/dtype/empty.h" | |||
| namespace mindspore { | |||
| class TestDType : public UT::Common { | |||
| @@ -92,21 +92,22 @@ class TestTensor : public UT::Common { | |||
| TestTensor() {} | |||
| virtual void SetUp() { | |||
| UT::InitPythonPath(); | |||
| // Init tensor data by py::array_t<float> | |||
| input_ = py::array_t<float, py::array::c_style>({2, 3}); | |||
| auto array = input_.mutable_unchecked(); | |||
| float start = 0; | |||
| for (int i = 0; i < array.shape(0); i++) { | |||
| for (int j = 0; j < array.shape(1); j++) { | |||
| array(i, j) = start++; | |||
| } | |||
| } | |||
| } | |||
| protected: | |||
| py::array_t<float, py::array::c_style> input_; | |||
| }; | |||
| py::array_t<float, py::array::c_style> BuildInputTensor() { | |||
| // Init tensor data by py::array_t<float> | |||
| py::array_t<float, py::array::c_style> input = py::array_t<float, py::array::c_style>({2, 3}); | |||
| auto array = input.mutable_unchecked(); | |||
| float start = 0; | |||
| for (int i = 0; i < array.shape(0); i++) { | |||
| for (int j = 0; j < array.shape(1); j++) { | |||
| array(i, j) = start++; | |||
| } | |||
| } | |||
| return input; | |||
| } | |||
| TEST_F(TestTensor, PyArrayScalarTest) { | |||
| std::vector<int> dimensions; | |||
| py::array data = py::array_t<int64_t, py::array::c_style>(dimensions); | |||
| @@ -246,7 +247,7 @@ TEST_F(TestTensor, PyArrayTest) { | |||
| TEST_F(TestTensor, InitByFloatArrayDataCTest) { | |||
| // Init tensor data by py::array_t<float> | |||
| TensorPtr tensor = std::make_shared<Tensor>(input_); | |||
| auto tensor = std::make_shared<Tensor>(BuildInputTensor()); | |||
| // Print some information of the tensor | |||
| std::cout << "Datatype: " << tensor->data_type() << std::endl; | |||
| @@ -268,7 +269,7 @@ TEST_F(TestTensor, InitByFloatArrayDataCTest) { | |||
| TEST_F(TestTensor, InitByFloatArrayDataTest) { | |||
| // Init tensor data by py::array_t<float> | |||
| TensorPtr tensor = std::make_shared<Tensor>(input_); | |||
| TensorPtr tensor = std::make_shared<Tensor>(BuildInputTensor()); | |||
| // Print some information of the tensor | |||
| std::cout << "Datatype: " << tensor->data_type() << std::endl; | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "common/common_test.h" | |||
| #include "ir/dtype/number.h" | |||
| #include "parallel/device_manager.h" | |||
| #include "parallel/auto_parallel/edge_costmodel.h" | |||
| #include "parallel/ops_info/matmul_info.h" | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "parallel/tensor_layout/util_layout_gen_test.h" | |||
| #include <cmath> | |||
| #include <map> | |||
| #include <tuple> | |||
| #include <vector> | |||
| @@ -23,6 +24,8 @@ | |||
| #include "parallel/tensor_layout/shape_util.h" | |||
| #include "common/common_test.h" | |||
| using std::pow; | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| std::vector<std::vector<int32_t>> combine(const std::vector<int32_t>& in, int32_t target) { | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "common/common_test.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "operator/ops.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -765,7 +766,8 @@ TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) { | |||
| py::object obj; | |||
| auto parameter_node = kernel_graph->add_parameter(); | |||
| MS_EXCEPTION_IF_NULL(parameter_node); | |||
| parameter_node->set_default_param(obj); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(obj); | |||
| parameter_node->set_default_param(param_value_new); | |||
| EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node)); | |||
| EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error); | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "common/common_test.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "operator/ops.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -82,7 +83,8 @@ TEST_F(KernelGraphTest, NewParameter) { | |||
| auto weight_parameter_node = anf_graph->add_parameter(); | |||
| MS_EXCEPTION_IF_NULL(weight_parameter_node); | |||
| py::object obj; | |||
| weight_parameter_node->set_default_param(obj); | |||
| auto param_value_new = std::make_shared<ParamValuePy>(obj); | |||
| weight_parameter_node->set_default_param(param_value_new); | |||
| weight_parameter_node->set_abstract(x_abstract); | |||
| auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node); | |||
| EXPECT_NE(new_weight_parameter_node, nullptr); | |||