Browse Source

clean code

tags/v1.6.0
huangbingjian 4 years ago
parent
commit
e623173965
8 changed files with 84 additions and 82 deletions
  1. +13
    -9
      mindspore/ccsrc/debug/dump_proto.cc
  2. +36
    -27
      mindspore/ccsrc/frontend/operator/composite/do_signature.cc
  3. +0
    -1
      mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
  4. +1
    -2
      mindspore/ccsrc/frontend/optimizer/irpass/inline.h
  5. +23
    -16
      mindspore/core/load_mindir/anf_model_parser.cc
  6. +1
    -0
      mindspore/core/load_mindir/anf_model_parser.h
  7. +8
    -24
      mindspore/core/utils/check_convert_utils.cc
  8. +2
    -3
      mindspore/core/utils/check_convert_utils.h

+ 13
- 9
mindspore/ccsrc/debug/dump_proto.cc View File

@@ -124,6 +124,18 @@ void CheckIfValidType(const TypePtr &type) {
}
}

// Fill `type_proto` as a DT_TENSOR: record the tensor's element type and, when a
// concrete static shape is available, each dimension size.
// NOTE(review): assumes `type` is a TensorType — the dyn_cast result is
// dereferenced without a null check, so callers must guard with isa<TensorType>
// (the caller in SetNodeOutputType does).
void SetTensorType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
TypePtr elem_type = dyn_cast<TensorType>(type)->element();
type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
type_proto->set_data_type(irpb::DT_TENSOR);
// Only export dimensions for a concrete abstract::Shape; null or non-Shape
// (e.g. dynamic/unknown) shapes leave the proto's shape empty.
if (shape != nullptr && shape->isa<abstract::Shape>()) {
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
for (const auto &elem : shape_info->shape()) {
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
}
}
}

void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) {
if (type_proto == nullptr) {
return;
@@ -136,15 +148,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
if (type->isa<Number>()) {
type_proto->set_data_type(GetNumberDataType(type));
} else if (type->isa<TensorType>()) {
TypePtr elem_type = dyn_cast<TensorType>(type)->element();
type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
type_proto->set_data_type(irpb::DT_TENSOR);
if (shape != nullptr && shape->isa<abstract::Shape>()) {
abstract::ShapePtr shape_info = dyn_cast<abstract::Shape>(shape);
for (const auto &elem : shape_info->shape()) {
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
}
}
SetTensorType(type, shape, type_proto);
} else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE);


+ 36
- 27
mindspore/ccsrc/frontend/operator/composite/do_signature.cc View File

@@ -91,6 +91,28 @@ bool GetTensorOrScalarTypeInfo(const TypePtr &arg_type_origin, TypeId *arg_type_
return false;
}

// Adjust a previously computed maximum type id according to which kinds of
// inputs were seen (int8 tensors, int64 scalars, float32 scalars), and return
// the corrected id. Pure function: no side effects.
TypeId GetMaxTypeIdForNumber(TypeId max_type_id, bool has_int8, bool has_scalar_int64, bool has_scalar_float32) {
  TypeId result = max_type_id;
  // A uint8 maximum mixed with int8 inputs must widen to int16 so both
  // value ranges fit.
  if (has_int8 && result == kNumberTypeUInt8) {
    result = kNumberTypeInt16;
  }
  // If bool is the max type, see if there is scalar input; if so, the max is a
  // bool tensor and the scalar's type should be used instead.
  // For example: Tensor([True, True]) * 2, expected result is Tensor([2, 2]).
  if (result == kNumberTypeBool) {
    if (has_scalar_int64) {
      result = kNumberTypeInt64;
    }
    if (has_scalar_float32) {
      result = kNumberTypeFloat32;
    }
  }
  // A float32 scalar promotes any known non-floating maximum to float32.
  const bool is_floating = (result == kNumberTypeFloat16) || (result == kNumberTypeFloat32) ||
                           (result == kNumberTypeFloat64);
  if (has_scalar_float32 && !is_floating && result != kTypeUnknown) {
    result = kNumberTypeFloat32;
  }
  return result;
}

TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, const std::vector<size_t> &indices) {
TypeId max_type_id = kTypeUnknown;
size_t max_type_number = 0;
@@ -126,26 +148,7 @@ TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, const std::vector<s
SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
}
}

if (max_type_id == kNumberTypeUInt8 && has_int8) {
max_type_id = kNumberTypeInt16;
}
// if bool is the max type, see if there is scalar input
// if so, it means that max is bool tensor, use scalar type instead.
// for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2])
if (max_type_id == kNumberTypeBool) {
if (has_scalar_int64) {
max_type_id = kNumberTypeInt64;
}
if (has_scalar_float32) {
max_type_id = kNumberTypeFloat32;
}
}
if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 &&
max_type_id != kTypeUnknown && has_scalar_float32) {
max_type_id = kNumberTypeFloat32;
}
return max_type_id;
return GetMaxTypeIdForNumber(max_type_id, has_int8, has_scalar_int64, has_scalar_float32);
}

// Get the largest type of index in the same SignatureEnumDType of arguments.
@@ -257,6 +260,18 @@ void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBas
}
}

// Resolve the read/write attribute for the argument at `index`.
// - index within the declared signatures: use that signature's rw.
// - index past the end with a variadic signature: inherit the last
//   signature's rw.
// - otherwise (including an empty signature list): kRWDefault.
SignatureEnumRW GetSignatureEnumRW(size_t index, const std::vector<Signature> &signature, bool has_var) {
  SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
  // If sig_size is 0 use default.
  std::size_t sig_size = signature.size();
  if (index < sig_size) {
    sig = signature[index].rw;
  } else if (has_var && sig_size > 0) {
    // BUGFIX: guard sig_size > 0 — with an empty signature list the previous
    // condition (index >= sig_size) was always true when has_var was set, and
    // signature[sig_size - 1] underflowed size_t (out-of-bounds access).
    sig = signature[sig_size - 1].rw;
  }
  return sig;
}

AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> &params_list) {
// args: original inputs
@@ -277,14 +292,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
op_inputs.push_back(param);
continue;
}
SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
// If sig_size is 0 use default.
if (sig_size > 0 && i < sig_size) {
sig = signature[i].rw;
} else if (has_var && i >= sig_size) {
sig = signature[sig_size - 1].rw;
}

SignatureEnumRW sig = GetSignatureEnumRW(i, signature, has_var);
TypePtr type = args_spec_list[i]->BuildType();
if (type && type->isa<RefType>()) {
if (sig == SignatureEnumRW::kRWRead) {


+ 0
- 1
mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc View File

@@ -31,7 +31,6 @@
namespace mindspore {
namespace opt {
namespace {

using ParamUserMap = std::unordered_map<std::string, std::vector<size_t>>;
using LoadGraphMap = OrderedMap<std::string, std::vector<size_t>>;



+ 1
- 2
mindspore/ccsrc/frontend/optimizer/irpass/inline.h View File

@@ -146,9 +146,8 @@ class InlinerBase : public AnfVisitor {
if (IsForceInline(this, fg, node)) {
if (IsUniqueUse(nullptr, fg, nullptr)) {
return InlineMove(node, fg, args, inputs);
} else {
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}

if (IsUniqueUse(nullptr, fg, nullptr)) {


+ 23
- 16
mindspore/core/load_mindir/anf_model_parser.cc View File

@@ -960,6 +960,27 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
return std::make_shared<ValueNode>(prim);
}

// Clear the CNode's imported abstract when it cannot be trusted as-is.
// Returns true when the abstract was cleared (the caller stops further
// abstract processing for this node), false otherwise.
bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) {
// Handle control flow operator.
auto operatorPtr = cnode_ptr->input(0);
// Set abstract of switch(c,f,t),switchLayer(c,tup) and
// partial(func,args) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
cnode_ptr->set_abstract(nullptr);
return true;
}

// If the operator is not a primitive, its abstract will be set to null.
// Some operators do not exist in the front end, so the abstract of a primitive should be preserved.
// NOTE(review): the pre-refactor code (old SetCNodeAbstract) only cleared here
// when need_renormalize() was ALSO true; that condition was dropped in this
// extraction — confirm the behavior change is intentional.
if (prim == nullptr) {
cnode_ptr->set_abstract(nullptr);
return true;
}
return false;
}

void MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type) {
if (node_type == "UpdateState") {
cnode_ptr->set_abstract(kUMonad->ToAbstract());
@@ -984,22 +1005,7 @@ void MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, cons

// Set CNode abstract.
void MSANFModelParser::SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
const std::string &node_type = node_proto.op_type();
// Handle control flow operator.
auto operatorPtr = cnode_ptr->input(0);
// Set abstract of switch(c,f,t),switchLayer(c,tup) and
// partial(func,args) to null
auto prim = GetValueNode<PrimitivePtr>(operatorPtr);
if (IsPrimitiveEquals(prim::kPrimSwitch, prim) || IsPrimitiveEquals(prim::kPrimSwitchLayer, prim) ||
IsPrimitiveEquals(prim::kPrimPartial, prim)) {
cnode_ptr->set_abstract(nullptr);
return;
}

// If the operator is not a primitive, its abstract will be set to null.
// Some operators do not exist in the front end, so the abstract of a primitive should be preserved.
if (prim == nullptr && need_renormalize()) {
cnode_ptr->set_abstract(nullptr);
if (CheckCNodePrim(cnode_ptr)) {
return;
}

@@ -1020,6 +1026,7 @@ void MSANFModelParser::SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CN

// Because there is not context in unit test,
// abstract->broaden() is replaced by abstract->set_value(kAnyValue).
const std::string &node_type = node_proto.op_type();
if (kv.size() == 0) {
SetEmptyTensorProtoCNodeAbstract(cnode_ptr, node_type);
} else if (kv.size() == 1 && !is_tuple_or_list) {


+ 1
- 0
mindspore/core/load_mindir/anf_model_parser.h View File

@@ -71,6 +71,7 @@ class MSANFModelParser {
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
bool CheckCNodePrim(CNodePtr cnode_ptr);
void SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type);
void SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);


+ 8
- 24
mindspore/core/utils/check_convert_utils.cc View File

@@ -571,7 +571,8 @@ ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_na
}

TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const TypePtr &type,
const std::set<TypePtr> &template_types, const string &prim_name) {
const std::set<TypePtr> &template_types, const string &prim_name,
bool is_mix) {
if (CheckType(type, template_types)) {
return type;
}
@@ -584,6 +585,11 @@ TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const
}
buffer << " Tensor[" << item->ToString() << "],";
}
if (is_mix) {
for (const auto &item : template_types) {
buffer << " " << item->ToString() << "],";
}
}
buffer << "}, but got " << type->ToString();
buffer << ".";
MS_EXCEPTION(TypeError) << buffer.str();
@@ -613,7 +619,7 @@ TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::s
(void)input_names.append(item.first);
(void)input_names.append(", ");
}
return CheckMixSubClass(input_names, arg_, valid_values, prim_name);
return CheckTensorSubClass(input_names, arg_, valid_values, prim_name, true);
}

TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
@@ -809,26 +815,4 @@ bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_l
}
return false;
}

// Check that `type` matches one of `template_types` (directly or as a tensor
// element type, per CheckType). Returns `type` on success; otherwise throws a
// TypeError listing every accepted candidate both as Tensor[T] and as plain T,
// since this check allows mixed scalar/tensor inputs.
TypePtr CheckAndConvertUtils::CheckMixSubClass(const string &type_name, const TypePtr &type,
                                               const std::set<TypePtr> &template_types, const string &prim_name) {
  if (CheckType(type, template_types)) {
    return type;
  }
  std::ostringstream buffer;
  buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of {";
  // Tensor forms of each candidate. BUGFIX: the tensor-type branch previously
  // streamed the type with no separator or trailing comma, fusing it with the
  // next entry in the message.
  for (const auto &item : template_types) {
    if (item->isa<TensorType>()) {
      buffer << " " << item->ToString() << ",";
      continue;
    }
    buffer << " Tensor[" << item->ToString() << "],";
  }
  // Plain (scalar) forms of each candidate. BUGFIX: previously emitted a stray
  // "]" after each plain type (leftover from the Tensor[...] pattern), which
  // produced unbalanced brackets in the error text.
  for (const auto &item : template_types) {
    buffer << " " << item->ToString() << ",";
  }
  buffer << "}, but got " << type->ToString();
  buffer << ".";
  MS_EXCEPTION(TypeError) << buffer.str();
}
} // namespace mindspore

+ 2
- 3
mindspore/core/utils/check_convert_utils.h View File

@@ -322,9 +322,8 @@ class CheckAndConvertUtils {
static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
const bool allow_mix);
static TypePtr CheckTensorSubClass(const std::string &type_name, const TypePtr &type,
const std::set<TypePtr> &template_types, const std::string &prim_name);
static TypePtr CheckMixSubClass(const std::string &type_name, const TypePtr &type,
const std::set<TypePtr> &template_types, const std::string &prim_name);
const std::set<TypePtr> &template_types, const std::string &prim_name,
bool is_mix = false);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_

Loading…
Cancel
Save