From 40e15996b06d9689567fe890e83ecc6c96007bc7 Mon Sep 17 00:00:00 2001
From: leopz
Date: Tue, 19 May 2020 00:40:32 +0800
Subject: [PATCH] move default_param out of Parameter and remove pybind11 from
 the ANF definition

---
 mindspore/ccsrc/debug/anf_ir_utils.cc         |   7 +-
 mindspore/ccsrc/debug/draw.cc                 |  31 +-
 mindspore/ccsrc/debug/info.cc                 |   4 +-
 mindspore/ccsrc/debug/info.h                  |   3 -
 mindspore/ccsrc/debug/label.h                 |   1 -
 mindspore/ccsrc/debug/trace_info.cc           |   2 -
 mindspore/ccsrc/debug/trace_info.h            |   3 -
 mindspore/ccsrc/ir/anf.cc                     |  78 +--
 mindspore/ccsrc/ir/anf.h                      |  21 +-
 mindspore/ccsrc/ir/anf_extends.cc             | 103 ++++
 mindspore/ccsrc/ir/dtype.cc                   | 461 -----------------
 mindspore/ccsrc/ir/dtype/container.cc         |   3 -
 mindspore/ccsrc/ir/dtype/number.cc            |   3 -
 mindspore/ccsrc/ir/dtype/ref.cc               |   3 -
 mindspore/ccsrc/ir/dtype/type.cc              |  10 +-
 mindspore/ccsrc/ir/dtype/type_extends.cc      |  25 +
 mindspore/ccsrc/ir/dtype_extends.cc           | 484 ++++++++++++++++++
 mindspore/ccsrc/ir/func_graph_cloner.cc       |   9 +-
 mindspore/ccsrc/ir/func_graph_cloner.h        |   1 +
 mindspore/ccsrc/ir/meta_func_graph.h          |   1 +
 mindspore/ccsrc/ir/named.h                    |   1 -
 mindspore/ccsrc/ir/param_value_minnie.h       |  43 ++
 mindspore/ccsrc/ir/param_value_py.h           |  43 ++
 mindspore/ccsrc/ir/scalar.h                   |   9 +-
 mindspore/ccsrc/ir/value.cc                   |  68 +--
 mindspore/ccsrc/ir/value.h                    |   1 +
 mindspore/ccsrc/ir/value_extends.cc           |  91 ++++
 mindspore/ccsrc/onnx/onnx_exporter.cc         |   4 +-
 .../parallel/graph_util/get_parallel_info.h   |   3 +-
 .../ccsrc/parallel/graph_util/node_info.cc    |   4 +-
 .../ccsrc/parallel/step_auto_parallel.cc      |   9 +-
 mindspore/ccsrc/parallel/step_parallel.cc     |  16 +-
 mindspore/ccsrc/pipeline/action.cc            |   5 +-
 .../ccsrc/pipeline/parse/function_block.cc    |   3 +
 mindspore/ccsrc/pipeline/parse/parse_base.h   |   2 +
 mindspore/ccsrc/pipeline/parse/resolve.cc     |   6 +-
 mindspore/ccsrc/pipeline/pipeline.cc          |   8 +-
 mindspore/ccsrc/session/kernel_graph.cc       |   5 +-
 mindspore/ccsrc/session/session_basic.cc      |   9 +-
 mindspore/ccsrc/utils/base_ref.h              |   4 +-
 mindspore/ccsrc/utils/callbacks_ge.cc         |   7 +-
 tests/ut/cpp/common/py_func_graph_fetcher.h   |   2 +
 tests/ut/cpp/ir/dtype_test.cc                 |   4 +
 tests/ut/cpp/ir/meta_tensor_test.cc           |  29 +-
 .../auto_parallel/edge_costmodel_test.cc      |   1 +
 .../tensor_layout/util_layout_gen_test.cc     |   3 +
 .../cpp/session/anf_runtime_algorithm_test.cc |   4 +-
 tests/ut/cpp/session/kernel_graph_test.cc     |   4 +-
 48 files changed, 943 insertions(+), 698 deletions(-)
 create mode 100644 mindspore/ccsrc/ir/anf_extends.cc
 create mode 100644 mindspore/ccsrc/ir/dtype/type_extends.cc
 create mode 100644 mindspore/ccsrc/ir/dtype_extends.cc
 create mode 100644 mindspore/ccsrc/ir/param_value_minnie.h
 create mode 100644 mindspore/ccsrc/ir/param_value_py.h
 create mode 100644 mindspore/ccsrc/ir/value_extends.cc
 mode change 100755 => 100644 mindspore/ccsrc/session/kernel_graph.cc

diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc
index 6ebe3ad43f..5b80265453 100644
--- a/mindspore/ccsrc/debug/anf_ir_utils.cc
+++ b/mindspore/ccsrc/debug/anf_ir_utils.cc
@@ -26,6 +26,7 @@
 #include "utils/graph_utils.h"
 #include "utils/symbolic.h"
 #include "ir/meta_func_graph.h"
+#include "ir/param_value_py.h"
 #include "pipeline/parse/python_adapter.h"
 #include "pipeline/parse/resolve.h"
 #include "operator/composite/composite.h"
@@ -469,7 +470,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector
     if (param_ptr->has_default()) {
-      ofs << " = @" << DumpObject(param_ptr->default_param(), "D");
+      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
+      ofs << " = @" << DumpObject(param_value->value(), "D");
     }

     // output comment
@@ -1650,7 +1652,8 @@ class IrParser {
       // load parameter default value from serialized file
       py::object default_obj = LoadObject(lexer_.GetTokenText());
-      param->set_default_param(default_obj);
+      auto param_value_new = std::make_shared<ParamValuePy>(default_obj);
+      param->set_default_param(param_value_new);

       tok = lexer_.GetNextToken();
     }
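The pattern in the hunk above — wrapping a py::object in a ParamValuePy before storing it on a Parameter, and downcasting again at the point of use — recurs throughout this patch. A minimal sketch of the round trip, assuming only the ParamValue/ParamValuePy classes introduced later in this patch (MakeDefaultParam and ReadDefaultParam are hypothetical helpers, not part of the change):

    #include <memory>
    #include "ir/anf.h"             // ParamValue, Parameter (as modified by this patch)
    #include "ir/param_value_py.h"  // ParamValuePy (added by this patch)

    namespace py = pybind11;

    // Store a Python object as a parameter's default value.
    void MakeDefaultParam(const mindspore::ParameterPtr &param, const py::object &obj) {
      auto wrapped = std::make_shared<mindspore::ParamValuePy>(obj);
      param->set_default_param(wrapped);  // also flips has_default_
    }

    // Read it back in Python-aware code; returns py::none() if absent.
    py::object ReadDefaultParam(const mindspore::ParameterPtr &param) {
      if (!param->has_default()) {
        return py::none();
      }
      auto value = std::dynamic_pointer_cast<mindspore::ParamValuePy>(param->default_param());
      return value != nullptr ? value->value() : py::none();
    }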
diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc
index 70016e1995..39fae1d28f 100644
--- a/mindspore/ccsrc/debug/draw.cc
+++ b/mindspore/ccsrc/debug/draw.cc
@@ -21,12 +21,17 @@
 #include
 #include

+#include "pybind11/pybind11.h"
 #include "ir/meta_func_graph.h"
+#include "ir/param_value_py.h"
+#include "ir/primitive.h"
 #include "utils/graph_utils.h"
 #include "utils/utils.h"
 #include "operator/composite/composite.h"
 #include "ir/meta_tensor.h"

+namespace py = pybind11;
+
 namespace mindspore {

 // namespace to support debug utils
@@ -312,17 +317,21 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
   for (auto &parameter : key->parameters()) {
     buffer_ << "";
     buffer_ << parameter->ToString();
-    auto py_p = dyn_cast<Parameter>(parameter)->default_param();
-    if (py::hasattr(py_p, "default_input")) {
-      py_p = py_p.attr("default_input");
-      if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
-        std::shared_ptr<tensor::Tensor> m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
-        py::tuple shape = m_tensor->GetPyTupleShape();
-        buffer_ << "[" << std::string(py::str(shape)) << "]";
-      } else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
-        std::shared_ptr<tensor::MetaTensor> m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
-        py::tuple shape = m_tensor->GetPyTupleShape();
-        buffer_ << "[" << std::string(py::str(shape)) << "]";
+    auto param = parameter->cast<ParameterPtr>();
+    if (param->has_default()) {
+      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
+      auto py_p = param_value->value();
+      if (py::hasattr(py_p, "default_input")) {
+        py_p = py_p.attr("default_input");
+        if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
+          auto m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
+          py::tuple shape = m_tensor->GetPyTupleShape();
+          buffer_ << "[" << py::str(shape) << "]";
+        } else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
+          auto m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
+          py::tuple shape = m_tensor->GetPyTupleShape();
+          buffer_ << "[" << py::str(shape) << "]";
+        }
       }
     }
     buffer_ << "";
diff --git a/mindspore/ccsrc/debug/info.cc b/mindspore/ccsrc/debug/info.cc
index 7903e554d9..770192a81d 100644
--- a/mindspore/ccsrc/debug/info.cc
+++ b/mindspore/ccsrc/debug/info.cc
@@ -18,9 +18,9 @@
 #include
 #include
 #include
+#include

 #include "ir/anf.h"
-#include "pipeline/parse/parse.h"
-#include "pipeline/parse/python_adapter.h"
+#include "utils/convert_utils.h"

 namespace mindspore {
 std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
diff --git a/mindspore/ccsrc/debug/info.h b/mindspore/ccsrc/debug/info.h
index e8d02827d8..9ed216277e 100644
--- a/mindspore/ccsrc/debug/info.h
+++ b/mindspore/ccsrc/debug/info.h
@@ -24,13 +24,10 @@
 #include
 #include

-#include "pybind11/pybind11.h"
 #include "ir/base.h"
 #include "debug/trace_info.h"

 namespace mindspore {
-namespace py = pybind11;
-
 // namespace to support intermediate representation definition
 enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };
diff --git a/mindspore/ccsrc/debug/label.h b/mindspore/ccsrc/debug/label.h
index 38509e8b55..62b98defd3 100644
--- a/mindspore/ccsrc/debug/label.h
+++ b/mindspore/ccsrc/debug/label.h
@@ -21,7 +21,6 @@
 #include
 #include
 #include
-#include "utils/any.h"
 #include "ir/anf.h"

 namespace mindspore {
"ir/anf.h" namespace mindspore { diff --git a/mindspore/ccsrc/debug/trace_info.cc b/mindspore/ccsrc/debug/trace_info.cc index 19358e197a..048bf2bdf0 100644 --- a/mindspore/ccsrc/debug/trace_info.cc +++ b/mindspore/ccsrc/debug/trace_info.cc @@ -19,8 +19,6 @@ #include #include #include "ir/anf.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/python_adapter.h" namespace mindspore { std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { diff --git a/mindspore/ccsrc/debug/trace_info.h b/mindspore/ccsrc/debug/trace_info.h index 7d7b7d44b3..63551ba10e 100644 --- a/mindspore/ccsrc/debug/trace_info.h +++ b/mindspore/ccsrc/debug/trace_info.h @@ -24,12 +24,9 @@ #include #include -#include "pybind11/pybind11.h" #include "ir/base.h" namespace mindspore { -namespace py = pybind11; - class TraceInfo; using TraceInfoPtr = std::shared_ptr; class Location; diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc index 50fe184d3f..8a400d19d4 100644 --- a/mindspore/ccsrc/ir/anf.cc +++ b/mindspore/ccsrc/ir/anf.cc @@ -23,21 +23,11 @@ #include #include -#include "ir/visitor.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "operator/ops.h" -#include "parallel/ops_info/ops_utils.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" namespace mindspore { // namespace to support intermediate representation definition -// Methods of AnfNode -TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); } -BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } - -std::string AnfNode::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); -} - CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} @@ -85,66 +75,6 @@ std::string CNode::DebugString(int recursive_level) const { return buffer.str(); } -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { - if (operator_info_ != nullptr) { - MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() - << ", using the new one: " << operator_info->name(); - auto old_ptr = operator_info_; - operator_info_ = operator_info; - return old_ptr; - } - operator_info_ = operator_info; - return nullptr; -} - -std::string CNode::fullname_with_scope() { - // if full name is set, return its name immediately - if (!fullname_with_scope_.empty()) { - return fullname_with_scope_; - } - - if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || - IsApply(prim::kPrimHistogramSummary)) { - std::string tag = GetValue(GetValueNode(input(1))); - if (tag == "") { - MS_LOG(EXCEPTION) << "The tag name is null, should be valid string"; - } - std::string name; - if (IsApply(prim::kPrimScalarSummary)) { - name = tag + "[:Scalar]"; - } else if (IsApply(prim::kPrimImageSummary)) { - name = tag + "[:Image]"; - } else if (IsApply(prim::kPrimHistogramSummary)) { - name = tag + "[:Histogram]"; - } else { - name = tag + "[:Tensor]"; - } - fullname_with_scope_ = name; - } else { - // cnode input 0 should be primitive ptr - auto value_ptr = input(0)->cast(); - if (value_ptr == nullptr) { - MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; - fullname_with_scope_ = id_generator::get_id(shared_from_base()); - return fullname_with_scope_; 
-    }
-    auto input_value = value_ptr->value();
-    if (input_value == nullptr) {
-      MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
-      fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
-      return fullname_with_scope_;
-    }
-
-    PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
-    MS_EXCEPTION_IF_NULL(scope());
-    MS_EXCEPTION_IF_NULL(prim);
-    fullname_with_scope_ =
-      scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
-  }
-
-  return fullname_with_scope_;
-}
-
 std::string ValueNode::ToString() const {
   MS_EXCEPTION_IF_NULL(value_);
   if (value_->isa<FuncGraph>()) {
@@ -173,10 +103,6 @@ std::string ValueNode::fullname_with_scope() {
   return fullname_with_scope_;
 }

-void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
-void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
-void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
-
 bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
diff --git a/mindspore/ccsrc/ir/anf.h b/mindspore/ccsrc/ir/anf.h
index d3da155b50..c2db17aec5 100644
--- a/mindspore/ccsrc/ir/anf.h
+++ b/mindspore/ccsrc/ir/anf.h
@@ -52,6 +52,7 @@ class AbstractBase;
 }  // namespace abstract
 using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
 using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
+using AbstractBasePtrList = std::vector<AbstractBasePtr>;

 class ValueNode;
 using ValueNodePtr = std::shared_ptr<ValueNode>;
@@ -78,6 +79,13 @@ using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;

 class AnfVisitor;

+class ParamValue {
+ public:
+  ParamValue() = default;
+  virtual ~ParamValue() = default;
+};
+using ParamValuePtr = std::shared_ptr<ParamValue>;
+
 // AnfNode is the basic class of the IR definition derived from Base.
 // Only two types of nodes are derived: CNode and ANode.
 // Methods:
@@ -239,11 +247,11 @@ class ANode : public AnfNode {

 // Parameter represents the parameter inputs of a function. They have no value.
 // Attributes:
-// default_param_value_: used to hold the inputting tensor of the model.
 class Parameter : public ANode {
  public:
   explicit Parameter(const FuncGraphPtr &func_graph)
-      : ANode(func_graph), name_(""), has_default_(false), default_param_(py::none()), tensor_layout_(nullptr) {}
+      : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {}
   ~Parameter() override = default;
   MS_DECLARE_PARENT(Parameter, ANode);
@@ -254,12 +262,11 @@
   std::string fullname_with_scope() override { return name(); };

   bool has_default() const { return has_default_; }
-
-  py::object default_param() { return default_param_; }
-  void set_default_param(const py::object &obj) {
-    default_param_ = obj;
+  void set_default_param(ParamValuePtr param) {
+    default_param_ = param;
     has_default_ = true;
   }
+  ParamValuePtr default_param() const { return default_param_; }

   std::shared_ptr tensor_layout() const { return tensor_layout_; }
   void set_tensor_layout(const std::shared_ptr &tensor_layout) {
@@ -280,7 +287,7 @@
  private:
   std::string name_;
   bool has_default_;
-  py::object default_param_;
+  ParamValuePtr default_param_;
   std::shared_ptr tensor_layout_;
 };
 using ParameterPtr = std::shared_ptr<Parameter>;
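With default_param_ typed as the new polymorphic ParamValuePtr, ir/anf.h no longer names any pybind11 type, and each consumer downcasts to the concrete ParamValue it understands. A hedged sketch of such a dispatch, assuming the ParamValueMinnie subclass added later in this patch (SizeOfDefault is a hypothetical helper):

    #include <memory>
    #include "ir/anf.h"
    #include "ir/param_value_minnie.h"

    // Backend code can inspect a default value without linking against Python:
    // the dynamic_pointer_cast simply fails for ParamValuePy-backed parameters.
    size_t SizeOfDefault(const mindspore::ParameterPtr &param) {
      auto minnie = std::dynamic_pointer_cast<mindspore::ParamValueMinnie>(param->default_param());
      return minnie != nullptr ? minnie->tensor_size() : 0;
    }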
diff --git a/mindspore/ccsrc/ir/anf_extends.cc b/mindspore/ccsrc/ir/anf_extends.cc
new file mode 100644
index 0000000000..2dd1a6e2f4
--- /dev/null
+++ b/mindspore/ccsrc/ir/anf_extends.cc
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/anf.h"
+
+#include
+#include
+#include
+#include
+
+#include "ir/visitor.h"
+#include "pipeline/static_analysis/static_analysis.h"
+#include "operator/ops.h"
+#include "parallel/ops_info/ops_utils.h"
+
+namespace mindspore {
+// namespace to support intermediate representation definition
+// Methods of AnfNode
+TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
+BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
+
+std::string AnfNode::ToString() const {
+  return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
+}
+
+OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
+  if (operator_info_ != nullptr) {
+    MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
+                    << ", using the new one: " << operator_info->name();
+    auto old_ptr = operator_info_;
+    operator_info_ = operator_info;
+    return old_ptr;
+  }
+  operator_info_ = operator_info;
+  return nullptr;
+}
+
+std::string CNode::fullname_with_scope() {
+  // if full name is set, return its name immediately
+  if (!fullname_with_scope_.empty()) {
+    return fullname_with_scope_;
+  }
+
+  if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
+      IsApply(prim::kPrimHistogramSummary)) {
+    std::string tag = GetValue<std::string>(GetValueNode(input(1)));
+    if (tag == "") {
+      MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
+    }
+    std::string name;
+    if (IsApply(prim::kPrimScalarSummary)) {
+      name = tag + "[:Scalar]";
+    } else if (IsApply(prim::kPrimImageSummary)) {
+      name = tag + "[:Image]";
+    } else if (IsApply(prim::kPrimHistogramSummary)) {
+      name = tag + "[:Histogram]";
+    } else {
+      name = tag + "[:Tensor]";
+    }
+    fullname_with_scope_ = name;
+  } else {
+    // cnode input 0 should be primitive ptr
+    auto value_ptr = input(0)->cast<ValueNodePtr>();
+    if (value_ptr == nullptr) {
+      MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
+      fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
+      return fullname_with_scope_;
+    }
+    auto input_value = value_ptr->value();
+    if (input_value == nullptr) {
+      MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
+      fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
+      return fullname_with_scope_;
+    }
+
+    PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
+    MS_EXCEPTION_IF_NULL(scope());
+    MS_EXCEPTION_IF_NULL(prim);
+    fullname_with_scope_ =
+      scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
+  }
+
+  return fullname_with_scope_;
+}
+
+void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
+void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
+void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
+
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/ccsrc/ir/dtype.cc
index 968ee9a524..ef3ce14bee 100644
--- a/mindspore/ccsrc/ir/dtype.cc
+++ b/mindspore/ccsrc/ir/dtype.cc
@@ -19,9 +19,6 @@
 #include
 #include
 #include "utils/log_adapter.h"
-#include "pipeline/static_analysis/abstract_value.h"
-#include "pybind_api/api_register.h"
-#include "pybind_api/export_flags.h"

 namespace mindspore {
 TypePtr Keyword::DeepCopy() const {
@@ -206,8 +203,6 @@ std::string Function::ToString() const {
   return buffer.str();
 }

-TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
-
 TypePtr JTagged::DeepCopy() const {
   MS_EXCEPTION_IF_NULL(subtype_);
   if (IsGeneric()) {
@@ -247,460 +242,4 @@ std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) {
   os << problem->ToString();
   return os;
 }
-
-std::size_t TypeHasher::operator()(TypePtr const &type) const {
-  MS_EXCEPTION_IF_NULL(type);
-  std::size_t hash = std::hash()(type->type_id());
-  return hash;
-}
-
-std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
-  std::size_t hash_sum = 0;
-  for (auto &type : type_list) {
-    auto type_id = static_cast<std::size_t>(type->type_id());
-    hash_sum = hash_combine(hash_sum, type_id);
-  }
-  return hash_sum;
-}
-
-bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
-  MS_EXCEPTION_IF_NULL(t1);
-  MS_EXCEPTION_IF_NULL(t2);
-  return t1->type_id() == t2->type_id();
-}
-
-bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
-  if (lhs.size() != rhs.size()) {
-    return false;
-  }
-  std::size_t size = lhs.size();
-  for (std::size_t i = 0; i < size; ++i) {
-    MS_EXCEPTION_IF_NULL(lhs[i]);
-    MS_EXCEPTION_IF_NULL(rhs[i]);
-    if (*lhs[i] != *rhs[i]) {
-      return false;
-    }
-  }
-  return true;
-}
-
-TypePtr TypeIdToType(TypeId id) {
-  switch (id) {
-    case kNumberTypeFloat16:
-      return kFloat16;
-    case kNumberTypeFloat:
-    case kNumberTypeFloat32:
-      return kFloat32;
-    case kNumberTypeFloat64:
-      return kFloat64;
-    case kNumberTypeInt8:
-      return kInt8;
-    case kNumberTypeInt16:
-      return kInt16;
-    case kNumberTypeInt32:
-      return kInt32;
-    case kNumberTypeInt64:
-      return kInt64;
-    case kNumberTypeUInt8:
-      return kUInt8;
-    case kNumberTypeUInt16:
-      return kUInt16;
-    case kNumberTypeUInt32:
-      return kUInt32;
-    case kNumberTypeUInt64:
-      return kUInt64;
-    case kNumberTypeBool:
-      return kBool;
-    case kMetaTypeExternal:
-      return kTypeExternal;
-    case kMetaTypeAnything:
-      return kAnyType;
-    case kMetaTypeNone:
-      return kTypeNone;
-    case kObjectTypeEnvType:
-      return kTypeEnv;
-    case kObjectTypeRefKey:
-      return kRefKeyType;
-    case kObjectTypeRef:
-      return kRefType;
-    case kTypeUnknown:
-      return kTypeNone;
-    default:
-      MS_LOG(EXCEPTION) << "Not support the type: " << id;
-  }
-}
-
-namespace {
-template <typename T>
-TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
-  TypePtr type = nullptr;
-  if (type_name == num_type_name) {
-    type = std::make_shared<T>();
-  } else {
-    try {
-      if (num_type_name.size() >= type_name.size()) {
-        MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
-                          << ")";
-      }
-      auto bits = std::stoi(type_name.substr(num_type_name.size()));
-      type = std::make_shared<T>(bits);
-    } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what();
-    }
-  }
-  return type;
-}
-
-std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
-  std::vector<TypePtr> types;
-  if (type_names.length() == 0) {
-    return types;
-  }
-  std::string::size_type start = 0;
-  std::string::size_type end = type_names.find_first_of(',');
-  while (end != std::string::npos) {
-    types.push_back(StringToType(type_names.substr(start, end)));
-    // Skip ',' to find the next element.
-    start = end + 1;
-    end = type_names.find_first_of(',', start);
-  }
-  if (start >= type_names.size()) {
-    MS_LOG(EXCEPTION) << "Type name is empty string.";
-  }
-  types.push_back(StringToType(type_names.substr(start)));
-  return types;
-}
-
-TypePtr TensorStrToType(const std::string &type_name) {
-  TypePtr type = nullptr;
-  if (type_name == "Tensor") {
-    type = std::make_shared<TensorType>();
-  } else {
-    try {
-      auto start = type_name.find_first_of('[') + 1;
-      auto end = type_name.find_last_of(']');
-      if (start >= type_name.size()) {
-        return nullptr;
-      }
-      auto element_str = type_name.substr(start, end - start);
-      auto element_type = StringToType(element_str);
-      if (element_type == nullptr) {
-        return nullptr;
-      }
-      type = std::make_shared<TensorType>(element_type);
-    } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
-    }
-  }
-
-  return type;
-}
-
-TypePtr ListStrToType(const std::string &type_name) {
-  TypePtr type = nullptr;
-  if (type_name == "List") {
-    type = std::make_shared<List>();
-  } else {
-    try {
-      auto start = type_name.find_first_of('[') + 1;
-      auto end = type_name.find_last_of(']');
-      if (start >= type_name.size()) {
-        return nullptr;
-      }
-      std::string element_strs = type_name.substr(start, end - start);
-      std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
-      bool wrong =
-        std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
-      if (wrong) {
-        return nullptr;
-      }
-      type = std::make_shared<List>(element_types);
-    } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
-    }
-  }
-
-  return type;
-}
-
-TypePtr TupleStrToType(const std::string &type_name) {
-  TypePtr type = nullptr;
-  if (type_name == "Tuple") {
-    type = std::make_shared<Tuple>();
-  } else {
-    try {
-      size_t start = type_name.find_first_of('[') + 1;
-      size_t end = type_name.find_last_of(']');
-      if (start >= type_name.size()) {
-        return nullptr;
-      }
-      std::string element_strs = type_name.substr(start, end - start);
-      std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
-      bool wrong =
-        std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
-      if (wrong) {
-        return nullptr;
-      }
-      type = std::make_shared<Tuple>(element_types);
-    } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
-    }
-  }
-  return type;
-}
-
-TypePtr FunctionStrToType(const std::string &type_name) {
-  TypePtr type = nullptr;
-
-  if (type_name == "Function") {
-    type = std::make_shared<Function>();
-  } else {
-    try {
-      // format: [(para1, para2, para3, ...) retval]
-      size_t start = type_name.find_first_of('[') + 1;
-      size_t end = type_name.find_last_of(']');
-      if (start >= type_name.size()) {
-        return nullptr;
-      }
-      std::string str_all = type_name.substr(start, end - start);
-      size_t start_a = str_all.find_first_of('(') + 1;
-      size_t end_a = str_all.find_last_of(')');
-      if (start_a >= str_all.size()) {
-        return nullptr;
-      }
-      std::string str_args = str_all.substr(start_a, end_a - start_a);
-      // bypass " " between ")" and retval
-      start = end_a + 2;
-      if (start >= str_all.size()) {
-        return nullptr;
-      }
-      std::string str_retval = str_all.substr(start);
-
-      std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
-      TypePtr retval = StringToType(str_retval);
-      bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
-      if (retval == nullptr || wrong) {
-        return nullptr;
-      }
-      type = std::make_shared<Function>(args_type, retval);
-    } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
-    }
-  }
-  return type;
-}
-}  // namespace
-
-TypePtr StringToType(const std::string &type_name) {
-  TypePtr type = nullptr;
-  if (type_name.compare("None") == 0) {
-    type = std::make_shared<TypeNone>();
-  } else if (type_name.compare("Ellipsis") == 0) {
-    type = std::make_shared<TypeEllipsis>();
-  } else if (type_name.compare("TypeType") == 0) {
-    type = std::make_shared<TypeType>();
-  } else if (type_name.compare("SymbolicKeyType") == 0) {
-    type = std::make_shared<SymbolicKeyType>();
-  } else if (type_name.compare("RefKeyType") == 0) {
-    type = std::make_shared<RefKeyType>();
-  } else if (type_name.compare("EnvType") == 0) {
-    type = std::make_shared<EnvType>();
-  } else if (type_name.compare("Number") == 0) {
-    type = std::make_shared<Number>();
-  } else if (type_name.compare("Bool") == 0) {
-    type = std::make_shared<Bool>();
-  } else if (type_name.compare(0, strlen("Int"), "Int") == 0) {
-    type = StringToNumberType<Int>(type_name, "Int");
-  } else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
-    type = StringToNumberType<UInt>(type_name, "UInt");
-  } else if (type_name.compare(0, strlen("Float"), "Float") == 0) {
-    type = StringToNumberType<Float>(type_name, "Float");
-  } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
-    type = TensorStrToType(type_name);
-  } else if (type_name.compare(0, strlen("List"), "List") == 0) {
-    type = ListStrToType(type_name);
-  } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
-    type = TupleStrToType(type_name);
-  } else if (type_name.compare("Slice") == 0) {
-    type = std::make_shared<Slice>();
-  } else if (type_name.compare("Dictionary") == 0) {
-    type = std::make_shared<Dictionary>();
-  } else if (type_name.compare("String") == 0) {
-    type = std::make_shared<String>();
-  } else if (type_name.compare("Problem") == 0) {
-    type = std::make_shared<Problem>();
-  } else if (type_name.compare(0, strlen("Function"), "Function") == 0) {
-    type = FunctionStrToType(type_name);
-  } else {
-    // - unsupported to convert
-    // Class
-    // SymbolicType
-    // JTagged
-    // Anything
-    // External
-    // Problem
-    MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
-  }
-  return type;
-}
-
-bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
-  if (x == nullptr || base_type == nullptr) {
-    MS_LOG(ERROR) << "Type is nullptr.";
-    return false;
-  }
-  if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
-    return false;
-  } else if (!(base_type->IsGeneric())) {
-    return *(base_type) == *(x);
-  } else if (base_type->type_id() == x->type_id()) {
-    return true;
-  } else if (base_type->type_id() == x->generic_type_id()) {
-    return true;
-  } else if (base_type->type_id() == x->object_type()) {
-    return true;
-  } else if (base_type->type_id() == x->meta_type()) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
-bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
-  MS_EXCEPTION_IF_NULL(t1);
-  if (t1->type_id() == kTypeUnknown) {
-    return false;
-  } else if (t2 != nullptr) {
-    return IsIdentidityOrSubclass(t1, t2);
-  } else {
-    return true;
-  }
-}
-
-REGISTER_PYBIND_DEFINE(
-  typing, ([](py::module *const m) {
-    auto m_sub = m->def_submodule("typing", "submodule for dtype");
-    py::enum_<TypeId>(m_sub, "TypeId");
-    (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
-    (void)m_sub.def("load_type", &TypeIdToType, "load type");
-    (void)m_sub.def(
-      "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
-    (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
-      .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
-      .def("__eq__",
-           [](const TypePtr &t1, const TypePtr &t2) {
-             if (t1 != nullptr && t2 != nullptr) {
-               return *t1 == *t2;
-             }
-             return false;
-           })
-      .def("__hash__", &Type::hash)
-      .def("__str__", &Type::ToString)
-      .def("__repr__", &Type::ReprString)
-      .def("__deepcopy__", [](const TypePtr &t, py::dict) {
-        if (t == nullptr) {
-          return static_cast<TypePtr>(nullptr);
-        }
-        return t->DeepCopy();
-      });
-    (void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
-    (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
-      .def(py::init())
-      .def(py::pickle(
-        [](const Bool &) {  // __getstate__
-          return py::make_tuple();
-        },
-        [](const py::tuple &) {  // __setstate__
-          return std::make_shared<Bool>();
-        }));
-    (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
-      .def(py::init())
-      .def(py::init<int>(), py::arg("nbits"))
-      .def(py::pickle(
-        [](const Int &t) {  // __getstate__
-          /* Return a tuple that fully encodes the state of the object */
-          return py::make_tuple(py::int_(t.nbits()));
-        },
-        [](const py::tuple &t) {  // __setstate__
-          if (t.size() != 1) {
-            throw std::runtime_error("Invalid state!");
-          }
-          /* Create a new C++ instance */
-          Int data(t[0].cast<int>());
-          return data;
-        }));
-    (void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt")
-      .def(py::init())
-      .def(py::init<int>(), py::arg("nbits"))
-      .def(py::pickle(
-        [](const UInt &t) {  // __getstate__
-          /* Return a tuple that fully encodes the state of the object */
-          return py::make_tuple(py::int_(t.nbits()));
-        },
-        [](const py::tuple &t) {  // __setstate__
-          if (t.size() != 1) {
-            throw std::runtime_error("Invalid state!");
-          }
-          /* Create a new C++ instance */
-          UInt data(t[0].cast<int>());
-          return data;
-        }));
-    (void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float")
-      .def(py::init())
-      .def(py::init<int>(), py::arg("nbits"))
-      .def(py::pickle(
-        [](const Float &t) {  // __getstate__
-          /* Return a tuple that fully encodes the state of the object */
-          return py::make_tuple(py::int_(t.nbits()));
-        },
-        [](const py::tuple &t) {  // __setstate__
-          if (t.size() != 1) {
-            throw std::runtime_error("Invalid state!");
-          }
-          /* Create a new C++ instance */
-          Float data(t[0].cast<int>());
-          return data;
-        }));
-    (void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List")
-      .def(py::init())
-      .def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
-    (void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple")
-      .def(py::init())
-      .def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
-    (void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType")
-      .def(py::init())
-      .def(py::init<TypePtr>(), py::arg("element"))
-      .def("element_type", &TensorType::element)
-      .def(py::pickle(
-        [](const TensorType &t) {  // __getstate__
-          /* Return a tuple that fully encodes the state of the object */
-          return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
-        },
-        [](const py::tuple &t) {  // __setstate__
-          if (t.size() != 1) {
-            throw std::runtime_error("Invalid state!");
-          }
-          /* Create a new C++ instance */
-          TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<int>()))));
-          return data;
-        }));
-    (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
-      .def(py::init())
-      .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
-    (void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init());
-    (void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init());
-    (void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init());
-    (void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init());
-    (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
-    (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
-    (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
-    (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
-    (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
-  }));
-
-const TypePtr kTypeExternal = std::make_shared<External>();
-const TypePtr kTypeEnv = std::make_shared<EnvType>();
-const TypePtr kTypeType = std::make_shared<TypeType>();
-const TypePtr kTensorType = std::make_shared<TensorType>();
-const TypePtr kString = std::make_shared<String>();
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/dtype/container.cc b/mindspore/ccsrc/ir/dtype/container.cc
index 3f8244c2e3..082cebc82d 100644
--- a/mindspore/ccsrc/ir/dtype/container.cc
+++ b/mindspore/ccsrc/ir/dtype/container.cc
@@ -19,9 +19,6 @@
 #include
 #include
 #include "utils/log_adapter.h"
-#include "pipeline/static_analysis/abstract_value.h"
-#include "pybind_api/api_register.h"
-#include "pybind_api/export_flags.h"

 namespace mindspore {
 static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
diff --git a/mindspore/ccsrc/ir/dtype/number.cc b/mindspore/ccsrc/ir/dtype/number.cc
index 44ac9e8e6a..8bbfcd7e14 100644
--- a/mindspore/ccsrc/ir/dtype/number.cc
+++ b/mindspore/ccsrc/ir/dtype/number.cc
@@ -19,9 +19,6 @@
 #include
 #include
 #include "utils/log_adapter.h"
-#include "pipeline/static_analysis/abstract_value.h"
-#include "pybind_api/api_register.h"
-#include "pybind_api/export_flags.h"

 namespace mindspore {
 bool Number::operator==(const Type &other) const {
diff --git a/mindspore/ccsrc/ir/dtype/ref.cc b/mindspore/ccsrc/ir/dtype/ref.cc
index 9fa7f6750b..1cb601f4ae 100644
--- a/mindspore/ccsrc/ir/dtype/ref.cc
+++ b/mindspore/ccsrc/ir/dtype/ref.cc
@@ -19,9 +19,6 @@
 #include
 #include
 #include "utils/log_adapter.h"
-#include "pipeline/static_analysis/abstract_value.h"
-#include "pybind_api/api_register.h"
-#include "pybind_api/export_flags.h"

 namespace mindspore {
 TypePtr RefType::DeepCopy() const {
diff --git a/mindspore/ccsrc/ir/dtype/type.cc b/mindspore/ccsrc/ir/dtype/type.cc
index 56954495df..402010759d 100644
--- a/mindspore/ccsrc/ir/dtype/type.cc
+++ b/mindspore/ccsrc/ir/dtype/type.cc
@@ -21,9 +21,8 @@
 #include
 #include
 #include "utils/log_adapter.h"
-#include "pipeline/static_analysis/abstract_value.h"
-#include "pybind_api/api_register.h"
-#include "pybind_api/export_flags.h"
+#include "ir/dtype/number.h"
+#include "utils/convert_utils.h"

 namespace mindspore {
 TypeId IntBitsToTypeId(const int nbits) {
@@ -227,11 +226,6 @@ bool Type::operator==(const Value &other) const {
   }
 }

-abstract::AbstractBasePtr Type::ToAbstract() {
-  abstract::AbstractBasePtr ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
-  return ptr;
-}
-
 std::ostream &operator<<(std::ostream &os, const Type &type) {
   os << type.ToString();
   return os;
diff --git a/mindspore/ccsrc/ir/dtype/type_extends.cc b/mindspore/ccsrc/ir/dtype/type_extends.cc
new file mode 100644
index 0000000000..a77a6a9cba
--- /dev/null
+++ b/mindspore/ccsrc/ir/dtype/type_extends.cc
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/dtype/type.h"
+#include "pipeline/static_analysis/abstract_value.h"
+
+namespace mindspore {
+abstract::AbstractBasePtr Type::ToAbstract() {
+  auto ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
+  return ptr;
+}
+}  // namespace mindspore
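type_extends.cc isolates the only abstract_value-dependent member of Type, so ir/dtype/type.cc can now be compiled and linked without the static-analysis layer. Call sites are unchanged; a small usage sketch, assuming the kFloat32 singleton from ir/dtype:

    #include "ir/dtype.h"
    #include "pipeline/static_analysis/abstract_value.h"

    void Demo() {
      mindspore::TypePtr t = mindspore::kFloat32;  // plain IR object, no Python involved
      // ToAbstract() is now defined in type_extends.cc rather than type.cc.
      mindspore::abstract::AbstractBasePtr abs = t->ToAbstract();
      MS_LOG(INFO) << abs->ToString();
    }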
diff --git a/mindspore/ccsrc/ir/dtype_extends.cc b/mindspore/ccsrc/ir/dtype_extends.cc
new file mode 100644
index 0000000000..20c3c401e1
--- /dev/null
+++ b/mindspore/ccsrc/ir/dtype_extends.cc
@@ -0,0 +1,484 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/dtype.h"
+#include
+#include
+#include
+#include "utils/log_adapter.h"
+#include "pipeline/static_analysis/abstract_value.h"
+#include "pybind_api/api_register.h"
+#include "pybind_api/export_flags.h"
+
+namespace mindspore {
+TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
+
+std::size_t TypeHasher::operator()(TypePtr const &type) const {
+  MS_EXCEPTION_IF_NULL(type);
+  std::size_t hash = std::hash()(type->type_id());
+  return hash;
+}
+
+std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
+  std::size_t hash_sum = 0;
+  for (auto &type : type_list) {
+    auto type_id = static_cast<std::size_t>(type->type_id());
+    hash_sum = hash_combine(hash_sum, type_id);
+  }
+  return hash_sum;
+}
+
+bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
+  MS_EXCEPTION_IF_NULL(t1);
+  MS_EXCEPTION_IF_NULL(t2);
+  return t1->type_id() == t2->type_id();
+}
+
+bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+  std::size_t size = lhs.size();
+  for (std::size_t i = 0; i < size; ++i) {
+    MS_EXCEPTION_IF_NULL(lhs[i]);
+    MS_EXCEPTION_IF_NULL(rhs[i]);
+    if (*lhs[i] != *rhs[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+TypePtr TypeIdToType(TypeId id) {
+  switch (id) {
+    case kNumberTypeFloat16:
+      return kFloat16;
+    case kNumberTypeFloat:
+    case kNumberTypeFloat32:
+      return kFloat32;
+    case kNumberTypeFloat64:
+      return kFloat64;
+    case kNumberTypeInt8:
+      return kInt8;
+    case kNumberTypeInt16:
+      return kInt16;
+    case kNumberTypeInt32:
+      return kInt32;
+    case kNumberTypeInt64:
+      return kInt64;
+    case kNumberTypeUInt8:
+      return kUInt8;
+    case kNumberTypeUInt16:
+      return kUInt16;
+    case kNumberTypeUInt32:
+      return kUInt32;
+    case kNumberTypeUInt64:
+      return kUInt64;
+    case kNumberTypeBool:
+      return kBool;
+    case kMetaTypeExternal:
+      return kTypeExternal;
+    case kMetaTypeAnything:
+      return kAnyType;
+    case kMetaTypeNone:
+      return kTypeNone;
+    case kObjectTypeEnvType:
+      return kTypeEnv;
+    case kObjectTypeRefKey:
+      return kRefKeyType;
+    case kObjectTypeRef:
+      return kRefType;
+    case kTypeUnknown:
+      return kTypeNone;
+    default:
+      MS_LOG(EXCEPTION) << "Not support the type: " << id;
+  }
+}
+
+namespace {
+template <typename T>
+TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
+  TypePtr type = nullptr;
+  if (type_name == num_type_name) {
+    type = std::make_shared<T>();
+  } else {
+    try {
+      if (num_type_name.size() >= type_name.size()) {
+        MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
+                          << ")";
+      }
+      auto bits = std::stoi(type_name.substr(num_type_name.size()));
+      type = std::make_shared<T>(bits);
+    } catch (const std::exception &e) {
+      MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what();
+    }
+  }
+  return type;
+}
+
+std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
+  std::vector<TypePtr> types;
+  if (type_names.length() == 0) {
+    return types;
+  }
+  std::string::size_type start = 0;
+  std::string::size_type end = type_names.find_first_of(',');
+  while (end != std::string::npos) {
+    types.push_back(StringToType(type_names.substr(start, end)));
+    // Skip ',' to find the next element.
+    start = end + 1;
+    end = type_names.find_first_of(',', start);
+  }
+  if (start >= type_names.size()) {
+    MS_LOG(EXCEPTION) << "Type name is empty string.";
+  }
+  types.push_back(StringToType(type_names.substr(start)));
+  return types;
+}
+
+TypePtr TensorStrToType(const std::string &type_name) {
+  TypePtr type = nullptr;
+  if (type_name == "Tensor") {
+    type = std::make_shared<TensorType>();
+  } else {
+    try {
+      auto start = type_name.find_first_of('[') + 1;
+      auto end = type_name.find_last_of(']');
+      if (start >= type_name.size()) {
+        return nullptr;
+      }
+      auto element_str = type_name.substr(start, end - start);
+      auto element_type = StringToType(element_str);
+      if (element_type == nullptr) {
+        return nullptr;
+      }
+      type = std::make_shared<TensorType>(element_type);
+    } catch (const std::exception &e) {
+      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
+    }
+  }
+
+  return type;
+}
+
+TypePtr ListStrToType(const std::string &type_name) {
+  TypePtr type = nullptr;
+  if (type_name == "List") {
+    type = std::make_shared<List>();
+  } else {
+    try {
+      auto start = type_name.find_first_of('[') + 1;
+      auto end = type_name.find_last_of(']');
+      if (start >= type_name.size()) {
+        return nullptr;
+      }
+      std::string element_strs = type_name.substr(start, end - start);
+      std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
+      bool wrong =
+        std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
+      if (wrong) {
+        return nullptr;
+      }
+      type = std::make_shared<List>(element_types);
+    } catch (const std::exception &e) {
+      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
+    }
+  }
+
+  return type;
+}
+
+TypePtr TupleStrToType(const std::string &type_name) {
+  TypePtr type = nullptr;
+  if (type_name == "Tuple") {
+    type = std::make_shared<Tuple>();
+  } else {
+    try {
+      size_t start = type_name.find_first_of('[') + 1;
+      size_t end = type_name.find_last_of(']');
+      if (start >= type_name.size()) {
+        return nullptr;
+      }
+      std::string element_strs = type_name.substr(start, end - start);
+      std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
+      bool wrong =
+        std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
+      if (wrong) {
+        return nullptr;
+      }
+      type = std::make_shared<Tuple>(element_types);
+    } catch (const std::exception &e) {
+      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
+    }
+  }
+  return type;
+}
+
+TypePtr FunctionStrToType(const std::string &type_name) {
+  TypePtr type = nullptr;
+
+  if (type_name == "Function") {
+    type = std::make_shared<Function>();
+  } else {
+    try {
+      // format: [(para1, para2, para3, ...) retval]
+      size_t start = type_name.find_first_of('[') + 1;
+      size_t end = type_name.find_last_of(']');
+      if (start >= type_name.size()) {
+        return nullptr;
+      }
+      std::string str_all = type_name.substr(start, end - start);
+      size_t start_a = str_all.find_first_of('(') + 1;
+      size_t end_a = str_all.find_last_of(')');
+      if (start_a >= str_all.size()) {
+        return nullptr;
+      }
+      std::string str_args = str_all.substr(start_a, end_a - start_a);
+      // bypass " " between ")" and retval
+      start = end_a + 2;
+      if (start >= str_all.size()) {
+        return nullptr;
+      }
+      std::string str_retval = str_all.substr(start);
+
+      std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
+      TypePtr retval = StringToType(str_retval);
+      bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
+      if (retval == nullptr || wrong) {
+        return nullptr;
+      }
+      type = std::make_shared<Function>(args_type, retval);
+    } catch (const std::exception &e) {
+      MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what();
+    }
+  }
+  return type;
+}
+}  // namespace
+
+TypePtr StringToType(const std::string &type_name) {
+  TypePtr type = nullptr;
+  if (type_name.compare("None") == 0) {
+    type = std::make_shared<TypeNone>();
+  } else if (type_name.compare("Ellipsis") == 0) {
+    type = std::make_shared<TypeEllipsis>();
+  } else if (type_name.compare("TypeType") == 0) {
+    type = std::make_shared<TypeType>();
+  } else if (type_name.compare("SymbolicKeyType") == 0) {
+    type = std::make_shared<SymbolicKeyType>();
+  } else if (type_name.compare("RefKeyType") == 0) {
+    type = std::make_shared<RefKeyType>();
+  } else if (type_name.compare("EnvType") == 0) {
+    type = std::make_shared<EnvType>();
+  } else if (type_name.compare("Number") == 0) {
+    type = std::make_shared<Number>();
+  } else if (type_name.compare("Bool") == 0) {
+    type = std::make_shared<Bool>();
+  } else if (type_name.compare(0, strlen("Int"), "Int") == 0) {
+    type = StringToNumberType<Int>(type_name, "Int");
+  } else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) {
+    type = StringToNumberType<UInt>(type_name, "UInt");
+  } else if (type_name.compare(0, strlen("Float"), "Float") == 0) {
+    type = StringToNumberType<Float>(type_name, "Float");
+  } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
+    type = TensorStrToType(type_name);
+  } else if (type_name.compare(0, strlen("List"), "List") == 0) {
+    type = ListStrToType(type_name);
+  } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
+    type = TupleStrToType(type_name);
+  } else if (type_name.compare("Slice") == 0) {
+    type = std::make_shared<Slice>();
+  } else if (type_name.compare("Dictionary") == 0) {
+    type = std::make_shared<Dictionary>();
+  } else if (type_name.compare("String") == 0) {
+    type = std::make_shared<String>();
+  } else if (type_name.compare("Problem") == 0) {
+    type = std::make_shared<Problem>();
+  } else if (type_name.compare(0, strlen("Function"), "Function") == 0) {
+    type = FunctionStrToType(type_name);
+  } else {
+    // - unsupported to convert
+    // Class
+    // SymbolicType
+    // JTagged
+    // Anything
+    // External
+    // Problem
+    MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
+  }
+  return type;
+}
+
+bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
+  if (x == nullptr || base_type == nullptr) {
+    MS_LOG(ERROR) << "Type is nullptr.";
+    return false;
+  }
+  if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
+    return false;
+  } else if (!(base_type->IsGeneric())) {
+    return *(base_type) == *(x);
+  } else if (base_type->type_id() == x->type_id()) {
+    return true;
+  } else if (base_type->type_id() == x->generic_type_id()) {
+    return true;
+  } else if (base_type->type_id() == x->object_type()) {
+    return true;
+  } else if (base_type->type_id() == x->meta_type()) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
+  MS_EXCEPTION_IF_NULL(t1);
+  if (t1->type_id() == kTypeUnknown) {
+    return false;
+  } else if (t2 != nullptr) {
+    return IsIdentidityOrSubclass(t1, t2);
+  } else {
+    return true;
+  }
+}
+
+REGISTER_PYBIND_DEFINE(
+  typing, ([](py::module *const m) {
+    auto m_sub = m->def_submodule("typing", "submodule for dtype");
+    py::enum_<TypeId>(m_sub, "TypeId");
+    (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
+    (void)m_sub.def("load_type", &TypeIdToType, "load type");
+    (void)m_sub.def(
+      "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
+    (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
+      .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
+      .def("__eq__",
+           [](const TypePtr &t1, const TypePtr &t2) {
+             if (t1 != nullptr && t2 != nullptr) {
+               return *t1 == *t2;
+             }
+             return false;
+           })
+      .def("__hash__", &Type::hash)
+      .def("__str__", &Type::ToString)
+      .def("__repr__", &Type::ReprString)
+      .def("__deepcopy__", [](const TypePtr &t, py::dict) {
+        if (t == nullptr) {
+          return static_cast<TypePtr>(nullptr);
+        }
+        return t->DeepCopy();
+      });
+    (void)py::class_<Number, Type, std::shared_ptr<Number>>(m_sub, "Number").def(py::init());
+    (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
+      .def(py::init())
+      .def(py::pickle(
+        [](const Bool &) {  // __getstate__
+          return py::make_tuple();
+        },
+        [](const py::tuple &) {  // __setstate__
+          return std::make_shared<Bool>();
+        }));
+    (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
+      .def(py::init())
+      .def(py::init<int>(), py::arg("nbits"))
+      .def(py::pickle(
+        [](const Int &t) {  // __getstate__
+          /* Return a tuple that fully encodes the state of the object */
+          return py::make_tuple(py::int_(t.nbits()));
+        },
+        [](const py::tuple &t) {  // __setstate__
+          if (t.size() != 1) {
+            throw std::runtime_error("Invalid state!");
+          }
+          /* Create a new C++ instance */
+          Int data(t[0].cast<int>());
+          return data;
+        }));
+    (void)py::class_<UInt, Type, std::shared_ptr<UInt>>(m_sub, "UInt")
+      .def(py::init())
+      .def(py::init<int>(), py::arg("nbits"))
+      .def(py::pickle(
+        [](const UInt &t) {  // __getstate__
+          /* Return a tuple that fully encodes the state of the object */
+          return py::make_tuple(py::int_(t.nbits()));
+        },
+        [](const py::tuple &t) {  // __setstate__
+          if (t.size() != 1) {
+            throw std::runtime_error("Invalid state!");
+          }
+          /* Create a new C++ instance */
+          UInt data(t[0].cast<int>());
+          return data;
+        }));
+    (void)py::class_<Float, Type, std::shared_ptr<Float>>(m_sub, "Float")
+      .def(py::init())
+      .def(py::init<int>(), py::arg("nbits"))
+      .def(py::pickle(
+        [](const Float &t) {  // __getstate__
+          /* Return a tuple that fully encodes the state of the object */
+          return py::make_tuple(py::int_(t.nbits()));
+        },
+        [](const py::tuple &t) {  // __setstate__
+          if (t.size() != 1) {
+            throw std::runtime_error("Invalid state!");
+          }
+          /* Create a new C++ instance */
+          Float data(t[0].cast<int>());
+          return data;
+        }));
+    (void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List")
+      .def(py::init())
+      .def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
+    (void)py::class_<Tuple, Type, std::shared_ptr<Tuple>>(m_sub, "Tuple")
+      .def(py::init())
+      .def(py::init<std::vector<TypePtr>>(), py::arg("elements"));
+    (void)py::class_<TensorType, Type, std::shared_ptr<TensorType>>(m_sub, "TensorType")
+      .def(py::init())
+      .def(py::init<TypePtr>(), py::arg("element"))
+      .def("element_type", &TensorType::element)
+      .def(py::pickle(
+        [](const TensorType &t) {  // __getstate__
+          /* Return a tuple that fully encodes the state of the object */
+          return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
+        },
+        [](const py::tuple &t) {  // __setstate__
+          if (t.size() != 1) {
+            throw std::runtime_error("Invalid state!");
+          }
+          /* Create a new C++ instance */
+          TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<int>()))));
+          return data;
+        }));
+    (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
+      .def(py::init())
+      .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
+    (void)py::class_<Class, Type, std::shared_ptr<Class>>(m_sub, "Class").def(py::init());
+    (void)py::class_<SymbolicKeyType, Type, std::shared_ptr<SymbolicKeyType>>(m_sub, "SymbolicKeyType").def(py::init());
+    (void)py::class_<EnvType, Type, std::shared_ptr<EnvType>>(m_sub, "EnvType").def(py::init());
+    (void)py::class_<TypeNone, Type, std::shared_ptr<TypeNone>>(m_sub, "TypeNone").def(py::init());
+    (void)py::class_<TypeType, Type, std::shared_ptr<TypeType>>(m_sub, "TypeType").def(py::init());
+    (void)py::class_<String, Type, std::shared_ptr<String>>(m_sub, "String").def(py::init());
+    (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
+    (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
+    (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
+  }));
+
+const TypePtr kTypeExternal = std::make_shared<External>();
+const TypePtr kTypeEnv = std::make_shared<EnvType>();
+const TypePtr kTypeType = std::make_shared<TypeType>();
+const TypePtr kTensorType = std::make_shared<TensorType>();
+const TypePtr kString = std::make_shared<String>();
+}  // namespace mindspore
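dtype_extends.cc now owns the string/TypeId conversions together with the pybind registration, leaving dtype.cc pybind-free. A brief sketch of the conversion helpers it defines (function names as in the patch; the concrete results assume the parsing rules shown above):

    #include <memory>
    #include "ir/dtype.h"

    void RoundTrip() {
      using namespace mindspore;
      // "Tensor[Float32]" is parsed recursively: TensorStrToType -> StringToType("Float32").
      TypePtr t1 = StringToType("Tensor[Float32]");
      // TypeIdToType returns the shared singleton, e.g. kInt32.
      TypePtr t2 = TypeIdToType(kNumberTypeInt32);
      // A concrete Int32 is a subtype of the generic Number.
      bool ok = IsSubType(t2, std::make_shared<Number>());
    }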
diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc
index c8012276f1..ab0a4fb19c 100644
--- a/mindspore/ccsrc/ir/func_graph_cloner.cc
+++ b/mindspore/ccsrc/ir/func_graph_cloner.cc
@@ -19,6 +19,7 @@
 #include

 #include "ir/manager.h"
+#include "ir/param_value_py.h"
 #include "operator/ops.h"
 #include "utils/log_adapter.h"
 #include "utils/profile.h"
@@ -69,7 +70,9 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
   new_param->set_abstract(old_param->abstract());
   new_param->set_name(old_param->name());
   if (old_param->has_default()) {
-    new_param->set_default_param(old_param->default_param());
+    auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
+    auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
+    new_param->set_default_param(param_value_new);
   }
   ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
   new_param->set_scope(scope);
@@ -248,7 +251,9 @@ void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
   if (node->isa<Parameter>()) {
     ParameterPtr old_param = dyn_cast<Parameter>(node);
     if (old_param->has_default()) {
-      param->set_default_param(old_param->default_param());
+      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
+      auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
+      param->set_default_param(param_value_new);
     }
     param->set_name(old_param->name());
   }
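Note that CloneParameter re-wraps rather than deep-copies: the new ParamValuePy holds the same underlying py::object as the original, so a cloned parameter aliases the source parameter's Python-side data. A hedged sketch of that behaviour (the asserts are illustrative, not part of the patch):

    #include <cassert>
    #include <memory>
    #include "ir/param_value_py.h"

    void CheckAliasing(const mindspore::ParameterPtr &old_param, const mindspore::ParameterPtr &new_param) {
      auto orig = std::dynamic_pointer_cast<mindspore::ParamValuePy>(old_param->default_param());
      auto copy = std::dynamic_pointer_cast<mindspore::ParamValuePy>(new_param->default_param());
      assert(orig != copy);                 // distinct C++ wrappers ...
      assert(orig->value().is(copy->value()));  // ... but the same Python object underneath
    }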
diff --git a/mindspore/ccsrc/ir/func_graph_cloner.h b/mindspore/ccsrc/ir/func_graph_cloner.h
index 218398347c..4279ddfa12 100644
--- a/mindspore/ccsrc/ir/func_graph_cloner.h
+++ b/mindspore/ccsrc/ir/func_graph_cloner.h
@@ -28,6 +28,7 @@
 #include "ir/anf.h"
 #include "ir/func_graph.h"
+#include "ir/manager.h"

 namespace mindspore {
 class Cloner;
diff --git a/mindspore/ccsrc/ir/meta_func_graph.h b/mindspore/ccsrc/ir/meta_func_graph.h
index 482b5f9025..533c66d40a 100644
--- a/mindspore/ccsrc/ir/meta_func_graph.h
+++ b/mindspore/ccsrc/ir/meta_func_graph.h
@@ -31,6 +31,7 @@
 #include "ir/dtype.h"
 #include "ir/anf.h"
 #include "ir/func_graph.h"
+#include "ir/signature.h"
 #include "pipeline/static_analysis/abstract_value.h"

 namespace py = pybind11;
diff --git a/mindspore/ccsrc/ir/named.h b/mindspore/ccsrc/ir/named.h
index 2d679c58b1..fbc7969cab 100644
--- a/mindspore/ccsrc/ir/named.h
+++ b/mindspore/ccsrc/ir/named.h
@@ -21,7 +21,6 @@
 #include
 #include

-#include "ir/base.h"
 #include "ir/anf.h"

 namespace mindspore {
diff --git a/mindspore/ccsrc/ir/param_value_minnie.h b/mindspore/ccsrc/ir/param_value_minnie.h
new file mode 100644
index 0000000000..9b648247e5
--- /dev/null
+++ b/mindspore/ccsrc/ir/param_value_minnie.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
+#define MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
+
+#include
+
+#include "ir/anf.h"
+
+namespace mindspore {
+class ParamValueMinnie : public ParamValue {
+ public:
+  ParamValueMinnie() : tensor_addr_(nullptr), tensor_size_(0) {}
+  virtual ~ParamValueMinnie() = default;
+
+  size_t tensor_size() const { return tensor_size_; }
+  void set_tensor_size(size_t size) { tensor_size_ = size; }
+
+  void *tensor_addr() const { return tensor_addr_; }
+  void set_tensor_addr(void *addr) { tensor_addr_ = addr; }
+
+ private:
+  void *tensor_addr_;
+  size_t tensor_size_;
+};
+
+using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>;
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
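ParamValueMinnie gives the lightweight, Python-free runtime a place to hang a raw device buffer on a Parameter through the same default_param slot. A hedged usage sketch (AttachBuffer and the buffer arguments are illustrative, not part of the patch):

    #include <memory>
    #include "ir/param_value_minnie.h"

    void AttachBuffer(const mindspore::ParameterPtr &param, void *addr, size_t nbytes) {
      auto value = std::make_shared<mindspore::ParamValueMinnie>();
      value->set_tensor_addr(addr);
      value->set_tensor_size(nbytes);
      param->set_default_param(value);  // same API as the Python-backed case
    }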
diff --git a/mindspore/ccsrc/ir/param_value_py.h b/mindspore/ccsrc/ir/param_value_py.h
new file mode 100644
index 0000000000..087dffaf60
--- /dev/null
+++ b/mindspore/ccsrc/ir/param_value_py.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
+#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
+
+#include
+
+#include "ir/anf.h"
+#include "pybind11/pybind11.h"
+
+namespace mindspore {
+namespace py = pybind11;
+
+class ParamValuePy : public ParamValue {
+ public:
+  ParamValuePy() : value_(py::none()) {}
+  explicit ParamValuePy(py::object value) : value_(value) {}
+  virtual ~ParamValuePy() = default;
+
+  py::object value() { return value_; }
+  void set_value(const py::object &obj) { value_ = obj; }
+
+ private:
+  py::object value_;
+};
+
+using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
diff --git a/mindspore/ccsrc/ir/scalar.h b/mindspore/ccsrc/ir/scalar.h
index ab6c485540..e8e29fb2f9 100644
--- a/mindspore/ccsrc/ir/scalar.h
+++ b/mindspore/ccsrc/ir/scalar.h
@@ -17,20 +17,23 @@
 #ifndef MINDSPORE_CCSRC_IR_SCALAR_H_
 #define MINDSPORE_CCSRC_IR_SCALAR_H_

-namespace mindspore {
-/* namespace to support inference engine */
-
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 #include
 #include
+
 #include "ir/base.h"
 #include "ir/dtype.h"
+#include "ir/dtype/number.h"
+
+using std::fabs;
+
+namespace mindspore {
 class Scalar : public Value {
  public:
   Scalar() = default;
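The scalar.h hunk also fixes a layering bug: the header previously opened namespace mindspore before its #include block, wrapping every included declaration in the wrong namespace. A minimal illustration of why that fails (hypothetical header, not from the patch):

    // Broken order (what scalar.h used to do):
    namespace mindspore {
    #include <vector>   // ill-formed in practice: the contents of <vector> land in
    }                   // mindspore::std, and a later plain #include <vector> is skipped
                        // by its include guard, leaving ::std::vector undeclared

    // Correct order, as the patch now has it:
    #include <vector>
    namespace mindspore {
    // declarations here may use ::std::vector normally
    }  // namespace mindspore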
diff --git a/mindspore/ccsrc/ir/value.cc b/mindspore/ccsrc/ir/value.cc
index e386e1ffd2..4dc0550c3d 100644
--- a/mindspore/ccsrc/ir/value.cc
+++ b/mindspore/ccsrc/ir/value.cc
@@ -19,9 +19,7 @@
 #include <algorithm>
 #include <memory>
 #include <cmath>
-
-#include "pybind_api/api_register.h"
-#include "pipeline/static_analysis/abstract_value.h"
+#include "utils/convert_utils.h"
 
 namespace mindspore {
 const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
@@ -208,41 +206,6 @@ bool AnyValue::operator==(const Value &other) const {
   }
 }
 const ValuePtr kAnyValue = std::make_shared<AnyValue>();
-using ContextPtr = abstract::AnalysisContextPtr;
-
-abstract::AbstractBasePtr Scalar::ToAbstract() {
-  return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>());
-}
-
-abstract::AbstractBasePtr StringImm::ToAbstract() {
-  return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>());
-}
-
-abstract::AbstractBasePtr RefKey::ToAbstract() {
-  auto refkey = std::make_shared<abstract::AbstractRefKey>();
-  refkey->set_value(shared_from_base<Value>());
-  return refkey;
-}
-
-abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractBase>(); }
-
-abstract::AbstractBasePtr ValueTuple::ToAbstract() {
-  abstract::AbstractBasePtrList a_list;
-  (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
-    MS_EXCEPTION_IF_NULL(ele);
-    return ele->ToAbstract();
-  });
-  return std::make_shared<abstract::AbstractTuple>(a_list);
-}
-
-abstract::AbstractBasePtr ValueList::ToAbstract() {
-  abstract::AbstractBasePtrList a_list;
-  (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
-    MS_EXCEPTION_IF_NULL(ele);
-    return ele->ToAbstract();
-  });
-  return std::make_shared<abstract::AbstractList>(a_list);
-}
 
 std::size_t ValueSlice::hash() const {
   MS_EXCEPTION_IF_NULL(start_);
@@ -280,16 +243,6 @@ std::string ValueSlice::ToString() const {
   return buffer.str();
 }
-
-abstract::AbstractBasePtr ValueSlice::ToAbstract() {
-  MS_EXCEPTION_IF_NULL(start_);
-  MS_EXCEPTION_IF_NULL(stop_);
-  MS_EXCEPTION_IF_NULL(step_);
-  abstract::AbstractBasePtr start = start_->ToAbstract();
-  abstract::AbstractBasePtr end = stop_->ToAbstract();
-  abstract::AbstractBasePtr step = step_->ToAbstract();
-  return std::make_shared<abstract::AbstractSlice>(start, end, step);
-}
-
 std::size_t KeywordArg::hash() const {
   MS_EXCEPTION_IF_NULL(value_);
   return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
@@ -316,12 +269,6 @@ std::string KeywordArg::ToString() const {
   return buffer.str();
 }
-
-abstract::AbstractBasePtr KeywordArg::ToAbstract() {
-  MS_EXCEPTION_IF_NULL(value_);
-  abstract::AbstractBasePtr argument = value_->ToAbstract();
-  return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
-}
-
 const ValuePtr ValueDictionary::operator[](const std::string &key) const {
   auto it = std::find_if(key_values_.begin(), key_values_.end(),
                          [key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
@@ -354,17 +301,4 @@ bool ValueDictionary::operator==(const ValueDictionary &other) const {
   }
   return true;
 }
-
-abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
-  std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
-  (void)std::transform(
-    key_values_.begin(), key_values_.end(), std::back_inserter(kv),
-    [](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
-  return std::make_shared<abstract::AbstractDictionary>(kv);
-}
-
-REGISTER_PYBIND_DEFINE(
-  RefKey, ([](const py::module *m) {
-    (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
-  }));
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/value.h b/mindspore/ccsrc/ir/value.h
index 160eac7b5c..ea9bb47ffe 100644
--- a/mindspore/ccsrc/ir/value.h
+++ b/mindspore/ccsrc/ir/value.h
@@ -29,6 +29,7 @@
 #include "ir/anf.h"
 #include "ir/dtype.h"
 #include "ir/scalar.h"
+#include "ir/dtype/ref.h"
 #include "utils/hashing.h"
 #include "common/utils.h"
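The ToAbstract overrides deleted here reappear verbatim in the new ir/value_extends.cc below, so ir/value.cc itself now compiles without pybind11 or the static-analysis headers; only binaries that link the extends object file pay for those dependencies. The split follows the pattern sketched here, with hypothetical names:

    // Hedged sketch of the "extends" split (names are hypothetical).
    #include <memory>

    // core.h part: the analysis type stays opaque to the core class.
    namespace analysis { class Abstract; }

    namespace core {
    struct Value {
      virtual ~Value() = default;
      std::shared_ptr<analysis::Abstract> ToAbstract();  // declared, not defined here
    };
    }  // namespace core

    // core_extends.cc part: the single TU that sees the full analysis type.
    namespace analysis { class Abstract {}; }

    std::shared_ptr<analysis::Abstract> core::Value::ToAbstract() {
      return std::make_shared<analysis::Abstract>();
    }

The virtual hook keeps the core ABI unchanged; moving the definitions only changes which translation unit carries the heavy includes.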
diff --git a/mindspore/ccsrc/ir/value_extends.cc b/mindspore/ccsrc/ir/value_extends.cc
new file mode 100644
index 0000000000..8eb34d0eeb
--- /dev/null
+++ b/mindspore/ccsrc/ir/value_extends.cc
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/value.h"
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "pybind_api/api_register.h"
+#include "pipeline/static_analysis/abstract_value.h"
+
+namespace mindspore {
+using ContextPtr = abstract::AnalysisContextPtr;
+
+abstract::AbstractBasePtr Scalar::ToAbstract() {
+  return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>());
+}
+
+abstract::AbstractBasePtr StringImm::ToAbstract() {
+  return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>());
+}
+
+abstract::AbstractBasePtr RefKey::ToAbstract() {
+  auto refkey = std::make_shared<abstract::AbstractRefKey>();
+  refkey->set_value(shared_from_base<Value>());
+  return refkey;
+}
+
+abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractBase>(); }
+
+abstract::AbstractBasePtr ValueTuple::ToAbstract() {
+  abstract::AbstractBasePtrList a_list;
+  (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
+    MS_EXCEPTION_IF_NULL(ele);
+    return ele->ToAbstract();
+  });
+  return std::make_shared<abstract::AbstractTuple>(a_list);
+}
+
+abstract::AbstractBasePtr ValueList::ToAbstract() {
+  abstract::AbstractBasePtrList a_list;
+  (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
+    MS_EXCEPTION_IF_NULL(ele);
+    return ele->ToAbstract();
+  });
+  return std::make_shared<abstract::AbstractList>(a_list);
+}
+
+abstract::AbstractBasePtr ValueSlice::ToAbstract() {
+  MS_EXCEPTION_IF_NULL(start_);
+  MS_EXCEPTION_IF_NULL(stop_);
+  MS_EXCEPTION_IF_NULL(step_);
+  abstract::AbstractBasePtr start = start_->ToAbstract();
+  abstract::AbstractBasePtr end = stop_->ToAbstract();
+  abstract::AbstractBasePtr step = step_->ToAbstract();
+  return std::make_shared<abstract::AbstractSlice>(start, end, step);
+}
+
+abstract::AbstractBasePtr KeywordArg::ToAbstract() {
+  MS_EXCEPTION_IF_NULL(value_);
+  abstract::AbstractBasePtr argument = value_->ToAbstract();
+  return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
+}
+
+abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
+  std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
+  (void)std::transform(
+    key_values_.begin(), key_values_.end(), std::back_inserter(kv),
+    [](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
+  return std::make_shared<abstract::AbstractDictionary>(kv);
+}
+
+REGISTER_PYBIND_DEFINE(
+  RefKey, ([](const py::module *m) {
+    (void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
+  }));
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc
index d53d1f63ed..f6f4ec2f1f 100644
--- a/mindspore/ccsrc/onnx/onnx_exporter.cc
+++ b/mindspore/ccsrc/onnx/onnx_exporter.cc
@@ -26,6 +26,7 @@
 #include "debug/anf_ir_utils.h"
 #include "proto/onnx.pb.h"
 #include "operator/ops.h"
+#include "ir/param_value_py.h"
 
 namespace mindspore {
 enum OpMergeMode {
@@ -424,7 +425,8 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
     initializer_proto->set_name(param_ptr->ToString());
     SetTensorProtoInfo(param_ptr, initializer_proto);
     // set value for initializer
-    py::object obj = param_ptr->default_param();
+    auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
+    py::object obj = param_value->value();
     py::object data = obj.attr("data");
     if (py::isinstance<tensor::Tensor>(data)) {
       auto method = data.attr("asnumpy");
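For export, the useful part of the default is the tensor payload behind the "data" attribute. A hedged sketch of the conversion the exporter performs after the hunk above; the helper name and the cast at the end are ours, only the attribute chain comes from the patch:

    // Hedged sketch (hypothetical helper): Python default -> numpy array.
    // Requires "pybind11/numpy.h" for py::array.
    py::array DefaultInputAsNumpy(const py::object &default_obj) {
      py::object data = default_obj.attr("data");  // the Tensor behind the Parameter
      return data.attr("asnumpy")().cast<py::array>();
    }

The resulting array's raw buffer is what gets copied into the ONNX initializer proto.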
diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h
index e21b81a557..e34d628b2b 100644
--- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h
+++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h
@@ -18,9 +18,10 @@
 #define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GET_GRAPH_INFO_H_
 
 #include "pybind11/stl.h"
-
+#include "pybind11/pybind11.h"
 #include "ir/anf.h"
 
+namespace py = pybind11;
 namespace mindspore {
 namespace parallel {
 py::dict GetParameterLayout(const FuncGraphPtr &graph);
diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.cc b/mindspore/ccsrc/parallel/graph_util/node_info.cc
index c085d71240..7298b06832 100644
--- a/mindspore/ccsrc/parallel/graph_util/node_info.cc
+++ b/mindspore/ccsrc/parallel/graph_util/node_info.cc
@@ -19,6 +19,7 @@
 #include <string>
 
 #include "ir/anf.h"
+#include "ir/param_value_py.h"
 #include "pipeline/parse/python_adapter.h"
 
 namespace mindspore {
@@ -37,7 +38,8 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
   if (!para_ptr->has_default()) {
     return false;
   }
-  return py::cast<bool>(parse::python_adapter::GetPyObjAttr(para_ptr->default_param(), "requires_grad"));
+  auto param_value = std::dynamic_pointer_cast<ParamValuePy>(para_ptr->default_param());
+  return py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
 }
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index f7a18c8b59..3811efdd6a 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -28,6 +28,7 @@
 #include <vector>
 
 #include "ir/anf.h"
+#include "ir/param_value_py.h"
 #include "ir/meta_tensor.h"
 #include "optimizer/opt.h"
 #include "optimizer/optimizer.h"
@@ -190,8 +191,8 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
       if (input_parameter->has_default()) {
-        bool require_grad =
-          py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
+        auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
+        bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
         is_parameter.push_back(require_grad);
       } else {
         is_parameter.push_back(false);
@@ -835,8 +836,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
       auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
       MS_EXCEPTION_IF_NULL(casted_target_parameter);
       if (casted_target_parameter->has_default()) {
-        bool require_grad = py::cast<bool>(
-          parse::python_adapter::GetPyObjAttr(casted_target_parameter->default_param(), "requires_grad"));
+        auto param_value = std::dynamic_pointer_cast<ParamValuePy>(casted_target_parameter->default_param());
+        bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
         is_parameter.push_back(require_grad);
       } else {
         is_parameter.push_back(false);
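node_info.cc and step_auto_parallel.cc above repeat the same downcast-then-getattr sequence at every call site. A hedged consolidation of that recurring pattern; the helper and its null check are our suggestion, not part of the patch:

    // Hedged sketch (hypothetical helper): one place for the downcast.
    #include "ir/param_value_py.h"

    bool ParamRequiresGrad(const ParameterPtr &param) {
      if (param == nullptr || !param->has_default()) {
        return false;
      }
      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
      if (param_value == nullptr) {
        return false;  // default is not a Python-backed value
      }
      return py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
    }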
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 291d628571..6c3b51347f 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -28,6 +28,7 @@
 #include <string>
 
 #include "ir/meta_tensor.h"
+#include "ir/param_value_py.h"
 #include "operator/ops.h"
 #include "optimizer/optimizer.h"
 #include "parallel/auto_parallel/graph_costmodel.h"
@@ -1292,7 +1293,8 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_node) {
     return false;
   }
 
-  py::object clone_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO);
+  auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
+  py::object clone_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
   bool cloned = py::cast<bool>(parse::python_adapter::GetPyObjAttr(clone_info, CLONED));
   if (!cloned) {
     return false;
@@ -1314,7 +1316,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
     }
 
     // get the cloned index
-    py::object cloned_info = parse::python_adapter::GetPyObjAttr(cloned_parameter->default_param(), CLONE_INFO);
+    auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
+    py::object cloned_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
     int32_t cloned_index = py::cast<int32_t>(parse::python_adapter::GetPyObjAttr(cloned_info, CLONED_INDEX));
 
     // find the be cloned parameter
@@ -1329,7 +1332,8 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
         continue;
       }
 
-      py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(be_cloned_parameter->default_param(), CLONE_INFO);
+      auto param_value_cloned = std::dynamic_pointer_cast<ParamValuePy>(be_cloned_parameter->default_param());
+      py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(param_value_cloned->value(), CLONE_INFO);
       if (!py::cast<bool>(parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED))) {
         continue;
       }
@@ -2072,9 +2076,9 @@ std::string NodeParameterName(const CNodePtr &node) {
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
       if (input_parameter->has_default()) {
-        if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) {
-          return py::cast<std::string>(
-            parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME));
+        auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
+        if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), REQUIRES_GRAD))) {
+          return py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), PARAM_NAME));
         }
       }
     }
diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc
index 5c8edd7c86..59a10afad4 100644
--- a/mindspore/ccsrc/pipeline/action.cc
+++ b/mindspore/ccsrc/pipeline/action.cc
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "ir/func_graph_cloner.h"
+#include "ir/param_value_py.h"
 #include "parallel/costmodel_context.h"
 #include "parallel/context.h"
 #include "pipeline/pass.h"
@@ -225,8 +226,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
   for (const auto &param : func_graph->parameters()) {
     auto param_node = std::static_pointer_cast<Parameter>(param);
     if (param_node->has_default()) {
-      AbstractBasePtr ptr =
-        abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true);
+      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
+      AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
 
       parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
       args_spec.push_back(ptr);
diff --git a/mindspore/ccsrc/pipeline/parse/function_block.cc b/mindspore/ccsrc/pipeline/parse/function_block.cc
index 16b0dfe30e..24e7ae74fb 100644
--- a/mindspore/ccsrc/pipeline/parse/function_block.cc
+++ b/mindspore/ccsrc/pipeline/parse/function_block.cc
@@ -25,8 +25,11 @@
 #include "operator/ops.h"
 #include "debug/info.h"
 #include "debug/trace.h"
+#include "pybind11/pybind11.h"
 
 namespace mindspore {
+namespace py = pybind11;
+
 namespace parse {
 FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
   func_graph_ = std::make_shared<FuncGraph>();
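Headers shared by Python-free code (debug/info.h, debug/label.h, ir/named.h) lose their pybind11 includes in this patch, while files that genuinely talk to Python declare the py alias themselves, as function_block.cc does above. The resulting layering, as a hedged sketch with a hypothetical consumer:

    // Hedged sketch: only pybind-aware translation units see pybind11.
    #include "pybind11/pybind11.h"

    namespace mindspore {
    namespace py = pybind11;  // alias scoped to files that need Python interop

    // Hypothetical consumer; the pure-IR headers it includes no longer
    // drag pybind11 into every other translation unit.
    std::string ReprOf(const py::object &obj) { return std::string(py::str(obj)); }
    }  // namespace mindspore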
diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h
index c7ce4e1196..ef1aeef55c 100644
--- a/mindspore/ccsrc/pipeline/parse/parse_base.h
+++ b/mindspore/ccsrc/pipeline/parse/parse_base.h
@@ -18,11 +18,13 @@
 #define PIPELINE_PARSE_PARSE_BASE_H_
 #include <string>
 #include <memory>
+#include "pybind11/pybind11.h"
 #include "ir/anf.h"
 #include "ir/func_graph.h"
 #include "ir/manager.h"
 #include "pybind_api/export_flags.h"
 
+namespace py = pybind11;
 namespace mindspore {
 namespace parse {
 // define the node type
diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc
index 576f63a1cf..d5e1f828cc 100644
--- a/mindspore/ccsrc/pipeline/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/parse/resolve.cc
@@ -21,6 +21,7 @@
 #include <string>
 #include <memory>
 
+#include "ir/param_value_py.h"
 #include "pipeline/parse/data_converter.h"
 #include "pipeline/parse/parse.h"
 #include "pipeline/parse/python_adapter.h"
@@ -101,8 +102,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
     }
   }
   if (para_node == nullptr) {
-    ParameterPtr node = top_graph->AddWeightParameter(param_name);
-    node->set_default_param(obj);
+    auto node = top_graph->AddWeightParameter(param_name);
+    auto param_value_new = std::make_shared<ParamValuePy>(obj);
+    node->set_default_param(param_value_new);
 
     // set_abstract for parameter
     auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index 2d1dafbb5f..6799a6bd77 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -24,6 +24,7 @@
 #include <map>
 
+#include "ir/param_value_py.h"
 #include "pipeline/pass.h"
 #include "pipeline/parse/data_converter.h"
 #include "optimizer/ad/dfunctor.h"
@@ -619,7 +620,12 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) {
   // maybe some default parameter
   for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
     MS_EXCEPTION_IF_NULL(graph_params[i]);
-    py::object obj = dyn_cast<Parameter>(graph_params[i])->default_param();
+    auto param_ptr = (graph_params[i])->cast<ParameterPtr>();
+    if (!param_ptr->has_default()) {
+      MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
+    }
+    auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
+    py::object obj = param_value->value();
     py::object p_value = py::cast<py::object>(parse::python_adapter::GetPyObjAttr(obj, "default_input"));
     (*arg_list).push_back(p_value);
   }
diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc
old mode 100755
new mode 100644
index 9146bdc622..aebf738419
--- a/mindspore/ccsrc/session/kernel_graph.cc
+++ b/mindspore/ccsrc/session/kernel_graph.cc
@@ -20,6 +20,7 @@
 #include <algorithm>
 #include "common/utils.h"
 #include "operator/ops.h"
+#include "ir/param_value_py.h"
 #include "session/anf_runtime_algorithm.h"
 #include "device/kernel_info.h"
 #include "kernel/kernel_build_info.h"
@@ -232,7 +233,9 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
   new_parameter->set_abstract(parameter->abstract());
   new_parameter->set_name(parameter->name());
   if (AnfAlgo::IsParameterWeight(parameter)) {
-    new_parameter->set_default_param(parameter->default_param());
+    auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
+    auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
+    new_parameter->set_default_param(param_value_new);
     kernel_info->SetFeatureMapFlag(false);
   } else {
     kernel_info->SetFeatureMapFlag(true);
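ProcessVmArg previously dereferenced dyn_cast<Parameter>(...)->default_param() unconditionally; the hunk above casts first and raises through MS_LOG(EXCEPTION) when no default exists. The same guard factored out, as a hedged sketch (helper name is ours):

    // Hedged sketch (hypothetical helper): guarded access to a Python default.
    py::object DefaultParamOrThrow(const AnfNodePtr &node, size_t index) {
      auto param_ptr = node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (!param_ptr->has_default()) {
        MS_LOG(EXCEPTION) << "Parameter[" << index << "] has no default param";
      }
      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
      MS_EXCEPTION_IF_NULL(param_value);  // our extra check; the patch assumes the cast succeeds
      return param_value->value();
    }

Failing loudly here is preferable to the old behavior, where a missing default produced a null-pointer dereference inside the VM argument loop.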
index 59ee4cc870..53e12ea69d 100644
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -20,6 +20,7 @@
 #include <algorithm>
 #include "pipeline/parse/data_converter.h"
 #include "ir/manager.h"
+#include "ir/param_value_py.h"
 #include "operator/ops.h"
 #include "common/trans.h"
 #include "utils/context/ms_context.h"
@@ -44,10 +45,11 @@ tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) {
     return nullptr;
   }
   auto parameter = node->cast<ParameterPtr>();
-  if (parameter == nullptr) {
+  if (parameter == nullptr || !parameter->has_default()) {
     return nullptr;
   }
-  auto py_param = parameter->default_param();
+  auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
+  auto py_param = param_value->value();
   if (!py::hasattr(py_param, "default_input")) {
     return nullptr;
   }
@@ -315,7 +317,8 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
   MS_EXCEPTION_IF_NULL(param);
   if (tensor_mask == 1) {
     py::object obj;
-    param->set_default_param(obj);
+    auto param_value_new = std::make_shared<ParamValuePy>(obj);
+    param->set_default_param(param_value_new);
   }
   // set the kernel info of parameter
   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
diff --git a/mindspore/ccsrc/utils/base_ref.h b/mindspore/ccsrc/utils/base_ref.h
index 74ccff8f80..e55cd39357 100644
--- a/mindspore/ccsrc/utils/base_ref.h
+++ b/mindspore/ccsrc/utils/base_ref.h
@@ -25,9 +25,11 @@
 #include <memory>
 #include <string>
 #include <vector>
-
+#include "pybind11/pybind11.h"
 #include "ir/value.h"
 
+namespace py = pybind11;
+
 namespace mindspore {
 class BaseRef;
 class VectorRef;
diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc
index da817b3f78..1f11ac4d0d 100644
--- a/mindspore/ccsrc/utils/callbacks_ge.cc
+++ b/mindspore/ccsrc/utils/callbacks_ge.cc
@@ -16,6 +16,7 @@
 #include "utils/callbacks_ge.h"
 #include "pybind11/pybind11.h"
+#include "ir/param_value_py.h"
 #include "transform/df_graph_manager.h"
 #include "transform/util.h"
 #include "pipeline/parse/data_converter.h"
@@ -49,7 +50,11 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string &param_name,
     return false;
   }
   if (param_node->name() == param_name) {
-    py::object parameter = param_node->default_param();
+    py::object parameter;
+    if (param_node->has_default()) {
+      auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
+      parameter = param_value->value();
+    }
     ValuePtr value = parse::data_converter::PyDataToValue(parameter);
     TensorPtr tensor = std::dynamic_pointer_cast<tensor::Tensor>(value);
     if (tensor == nullptr) {
diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h
index 3c9de9f971..98552a96b5 100644
--- a/tests/ut/cpp/common/py_func_graph_fetcher.h
+++ b/tests/ut/cpp/common/py_func_graph_fetcher.h
@@ -19,7 +19,9 @@
 #include <string>
 #include <memory>
 #include "ir/anf.h"
+#include "ir/primitive.h"
 #include "ir/manager.h"
+#include "ir/func_graph.h"
 #include "pipeline/parse/parse_base.h"
 #include "pipeline/parse/parse.h"
 #include "./common.h"
diff --git a/tests/ut/cpp/ir/dtype_test.cc b/tests/ut/cpp/ir/dtype_test.cc
index 5bfdfb3af6..d15abf395c 100644
--- a/tests/ut/cpp/ir/dtype_test.cc
+++ b/tests/ut/cpp/ir/dtype_test.cc
@@ -16,6 +16,10 @@
 #include <memory>
 #include "common/common_test.h"
 #include "ir/dtype.h"
+#include "ir/dtype/ref.h"
+#include "ir/dtype/number.h"
+#include "ir/dtype/container.h"
+#include "ir/dtype/empty.h"
 
 namespace mindspore {
 class TestDType : public UT::Common {
diff --git a/tests/ut/cpp/ir/meta_tensor_test.cc b/tests/ut/cpp/ir/meta_tensor_test.cc
index f149d5d154..310b6ebb2d 100644
--- a/tests/ut/cpp/ir/meta_tensor_test.cc
+++ b/tests/ut/cpp/ir/meta_tensor_test.cc
@@ -92,21 +92,22 @@ class TestTensor : public UT::Common {
   TestTensor() {}
   virtual void SetUp() {
     UT::InitPythonPath();
-    // Init tensor data by py::array_t<float>
-    input_ = py::array_t<float>({2, 3});
-    auto array = input_.mutable_unchecked();
-    float start = 0;
-    for (int i = 0; i < array.shape(0); i++) {
-      for (int j = 0; j < array.shape(1); j++) {
-        array(i, j) = start++;
-      }
-    }
   }
-
- protected:
-  py::array_t<float> input_;
 };
 
+py::array_t<float> BuildInputTensor() {
+  // Init tensor data by py::array_t<float>
+  py::array_t<float> input = py::array_t<float>({2, 3});
+  auto array = input.mutable_unchecked();
+  float start = 0;
+  for (int i = 0; i < array.shape(0); i++) {
+    for (int j = 0; j < array.shape(1); j++) {
+      array(i, j) = start++;
+    }
+  }
+  return input;
+}
+
 TEST_F(TestTensor, PyArrayScalarTest) {
   std::vector<int> dimensions;
   py::array data = py::array_t<float>(dimensions);
@@ -246,7 +247,7 @@ TEST_F(TestTensor, PyArrayTest) {
 
 TEST_F(TestTensor, InitByFloatArrayDataCTest) {
   // Init tensor data by py::array_t<float>
-  TensorPtr tensor = std::make_shared<Tensor>(input_);
+  auto tensor = std::make_shared<Tensor>(BuildInputTensor());
 
   // Print some information of the tensor
   std::cout << "Datatype: " << tensor->data_type() << std::endl;
@@ -268,7 +269,7 @@ TEST_F(TestTensor, InitByFloatArrayDataTest) {
 
 TEST_F(TestTensor, InitByFloatArrayDataTest) {
   // Init tensor data by py::array_t<float>
-  TensorPtr tensor = std::make_shared<Tensor>(input_);
+  TensorPtr tensor = std::make_shared<Tensor>(BuildInputTensor());
 
   // Print some information of the tensor
   std::cout << "Datatype: " << tensor->data_type() << std::endl;
diff --git a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc
index 423a258a28..291539c27d 100644
--- a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc
+++ b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc
@@ -15,6 +15,7 @@
  */
 
 #include "common/common_test.h"
+#include "ir/dtype/number.h"
 #include "parallel/device_manager.h"
 #include "parallel/auto_parallel/edge_costmodel.h"
 #include "parallel/ops_info/matmul_info.h"
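The meta_tensor test refactor above swaps a mutable fixture member for the BuildInputTensor() free function, so each test gets freshly initialized data and SetUp no longer holds pybind state. The pattern in miniature, as a hedged gtest sketch with hypothetical names:

    // Hedged sketch: builder function instead of shared fixture state.
    #include <gtest/gtest.h>
    #include <vector>

    static std::vector<float> BuildInput() {
      return {0, 1, 2, 3, 4, 5};  // fresh data on every call, no shared state
    }

    TEST(BuilderPattern, EachTestGetsFreshData) {
      auto input = BuildInput();
      input[0] = 42;                  // mutation stays local to this test
      EXPECT_EQ(BuildInput()[0], 0);  // other tests are unaffected
    }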
diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc
index 93147c486b..6f5c1e49ed 100644
--- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc
@@ -14,6 +14,7 @@
  * limitations under the License.
 */
 #include "parallel/tensor_layout/util_layout_gen_test.h"
+#include <cmath>
 #include <algorithm>
 #include <map>
 #include <vector>
@@ -23,6 +24,8 @@
 #include "parallel/tensor_layout/shape_util.h"
 #include "common/common_test.h"
 
+using std::pow;
+
 namespace mindspore {
 namespace parallel {
 std::vector<std::vector<int32_t>> combine(const std::vector<int32_t>& in, int32_t target) {
diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc
index 6375d1a758..9ff8123004 100644
--- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc
+++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc
@@ -15,6 +15,7 @@
  */
 
 #include "common/common_test.h"
+#include "ir/param_value_py.h"
 #include "operator/ops.h"
 #include "session/kernel_graph.h"
 #include "session/anf_runtime_algorithm.h"
@@ -765,7 +766,8 @@ TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) {
   py::object obj;
   auto parameter_node = kernel_graph->add_parameter();
   MS_EXCEPTION_IF_NULL(parameter_node);
-  parameter_node->set_default_param(obj);
+  auto param_value_new = std::make_shared<ParamValuePy>(obj);
+  parameter_node->set_default_param(param_value_new);
   EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node));
   EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error);
 }
diff --git a/tests/ut/cpp/session/kernel_graph_test.cc b/tests/ut/cpp/session/kernel_graph_test.cc
index a62af9c892..75e653c26c 100644
--- a/tests/ut/cpp/session/kernel_graph_test.cc
+++ b/tests/ut/cpp/session/kernel_graph_test.cc
@@ -15,6 +15,7 @@
  */
 
 #include "common/common_test.h"
+#include "ir/param_value_py.h"
 #include "operator/ops.h"
 #include "session/kernel_graph.h"
 #include "session/anf_runtime_algorithm.h"
@@ -82,7 +83,8 @@ TEST_F(KernelGraphTest, NewParameter) {
   auto weight_parameter_node = anf_graph->add_parameter();
   MS_EXCEPTION_IF_NULL(weight_parameter_node);
   py::object obj;
-  weight_parameter_node->set_default_param(obj);
+  auto param_value_new = std::make_shared<ParamValuePy>(obj);
+  weight_parameter_node->set_default_param(param_value_new);
   weight_parameter_node->set_abstract(x_abstract);
   auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node);
   EXPECT_NE(new_weight_parameter_node, nullptr);