| @@ -69,7 +69,7 @@ void PrintNodeOutputType(std::ostringstream &buffer, const AnfNodePtr &nd) { | |||
| TypePtr type = dyn_cast<Type>(nd->Type()); | |||
| if ((shape != nullptr) && (type != nullptr)) { | |||
| buffer << "<" << type << "x" << shape->shape() << ">"; | |||
| } else if (nullptr != type) { | |||
| } else if (type != nullptr) { | |||
| buffer << "<" << type << ">"; | |||
| } else { | |||
| buffer << "<null>"; | |||
| @@ -38,6 +38,8 @@ using mindspore::abstract::AbstractSparseTensor; | |||
| using mindspore::abstract::AbstractTuple; | |||
| using mindspore::abstract::AbstractUndetermined; | |||
| static constexpr size_t kDictInputSize = 2; | |||
| static AbstractBasePtr Reabs(const AbstractBasePtr &t) { | |||
| if (t == nullptr) { | |||
| return nullptr; | |||
| @@ -307,7 +309,7 @@ AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { | |||
| AnfNodePtr EraseDictGetValues(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto &inputs = node->inputs(); | |||
| MS_ASSERT(inputs.size() == 2 && "DictGetValues should have two inputs"); | |||
| MS_ASSERT(inputs.size() == kDictInputSize && "DictGetValues should have two inputs"); | |||
| return inputs[1]; | |||
| } | |||
| @@ -40,6 +40,10 @@ using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; | |||
| using InstanceCheckFunc = std::function<bool(const py::object &)>; | |||
| using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>; | |||
| static constexpr int kBit8 = 8; | |||
| static constexpr int kBit16 = 16; | |||
| static constexpr int kBit32 = 32; | |||
| static constexpr int kBit64 = 64; | |||
| class DataConverter { | |||
| public: | |||
| explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {} | |||
| @@ -362,16 +366,16 @@ ValuePtr ConvertNumberWithType(const T &obj, TypePtr dtype) { | |||
| auto int_dypte = dyn_cast<Int>(dtype); | |||
| if (int_dypte != nullptr) { | |||
| switch (int_dypte->nbits()) { | |||
| case 8: | |||
| case kBit8: | |||
| data = std::make_shared<Int8Imm>(obj); | |||
| break; | |||
| case 16: | |||
| case kBit16: | |||
| data = std::make_shared<Int16Imm>(obj); | |||
| break; | |||
| case 32: | |||
| case kBit32: | |||
| data = std::make_shared<Int32Imm>(obj); | |||
| break; | |||
| case 64: | |||
| case kBit64: | |||
| data = std::make_shared<Int64Imm>(obj); | |||
| break; | |||
| default: | |||
| @@ -383,16 +387,16 @@ ValuePtr ConvertNumberWithType(const T &obj, TypePtr dtype) { | |||
| auto uint_dypte = dyn_cast<UInt>(dtype); | |||
| if (uint_dypte != nullptr) { | |||
| switch (uint_dypte->nbits()) { | |||
| case 8: | |||
| case kBit8: | |||
| data = std::make_shared<UInt8Imm>(obj); | |||
| break; | |||
| case 16: | |||
| case kBit16: | |||
| data = std::make_shared<UInt16Imm>(obj); | |||
| break; | |||
| case 32: | |||
| case kBit32: | |||
| data = std::make_shared<UInt32Imm>(obj); | |||
| break; | |||
| case 64: | |||
| case kBit64: | |||
| data = std::make_shared<UInt64Imm>(obj); | |||
| break; | |||
| default: | |||
| @@ -404,10 +408,10 @@ ValuePtr ConvertNumberWithType(const T &obj, TypePtr dtype) { | |||
| auto float_dypte = dyn_cast<Float>(dtype); | |||
| if (float_dypte != nullptr) { | |||
| switch (float_dypte->nbits()) { | |||
| case 32: | |||
| case kBit32: | |||
| data = std::make_shared<FP32Imm>(obj); | |||
| break; | |||
| case 64: | |||
| case kBit64: | |||
| data = std::make_shared<FP64Imm>(obj); | |||
| break; | |||
| default: | |||
| @@ -519,20 +519,16 @@ AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &nod | |||
| py::object obj = python_adapter::GetPyObjAttr(node, "value"); | |||
| if (py::isinstance<py::bool_>(obj)) { | |||
| MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj); | |||
| auto data = py::cast<bool>(obj); | |||
| return NewValueNode(data); | |||
| return NewValueNode(py::cast<bool>(obj)); | |||
| } else if (py::isinstance<py::int_>(obj)) { | |||
| MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj); | |||
| auto data = py::cast<int64_t>(obj); | |||
| return NewValueNode(data); | |||
| return NewValueNode(py::cast<int64_t>(obj)); | |||
| } else if (py::isinstance<py::float_>(obj)) { | |||
| MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj); | |||
| auto data = py::cast<float>(obj); | |||
| return NewValueNode(data); | |||
| return NewValueNode(py::cast<float>(obj)); | |||
| } else if (py::isinstance<py::str>(obj)) { | |||
| MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj); | |||
| auto data = py::cast<std::string>(obj); | |||
| return NewValueNode(data); | |||
| return NewValueNode(py::cast<std::string>(obj)); | |||
| } else if (py::isinstance<py::none>(obj)) { | |||
| MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj); | |||
| return NewValueNode(kNone); | |||
| @@ -287,39 +287,39 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | |||
| MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); | |||
| // Join the last eval arguments and current arguments to check if there are loop variant. | |||
| auto joined_args_spec_list = AbstractJoin(args_spec_list, last_context->args_spec_list()); | |||
| MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); | |||
| auto joined_args_spec_list_1 = AbstractJoin(args_spec_list, last_context->args_spec_list()); | |||
| MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list_1); | |||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||
| if (!(joined_args_spec_list == args_spec_list)) { | |||
| if (!(joined_args_spec_list_1 == args_spec_list)) { | |||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| func_graph_->joined_shapes_.clear(); | |||
| std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), | |||
| std::transform(joined_args_spec_list_1.begin(), joined_args_spec_list_1.end(), | |||
| std::back_inserter(func_graph_->joined_shapes_), | |||
| [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); | |||
| joined_args_spec_list = NormalizeArgs(joined_args_spec_list); | |||
| joined_args_spec_list_1 = NormalizeArgs(joined_args_spec_list_1); | |||
| MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; | |||
| } | |||
| return joined_args_spec_list; | |||
| return joined_args_spec_list_1; | |||
| } | |||
| } | |||
| if (trace_.size() != 0) { | |||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | |||
| MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); | |||
| // Join the last eval arguments and current arguments to check if there are loop variant. | |||
| auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back()); | |||
| auto joined_args_spec_list_2 = AbstractJoin(args_spec_list, trace_.back()); | |||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||
| if (!(joined_args_spec_list == args_spec_list)) { | |||
| trace_.push_back(joined_args_spec_list); | |||
| if (!(joined_args_spec_list_2 == args_spec_list)) { | |||
| trace_.push_back(joined_args_spec_list_2); | |||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| func_graph_->joined_shapes_.clear(); | |||
| std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), | |||
| std::transform(joined_args_spec_list_2.begin(), joined_args_spec_list_2.end(), | |||
| std::back_inserter(func_graph_->joined_shapes_), | |||
| [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); | |||
| joined_args_spec_list = NormalizeArgs(joined_args_spec_list); | |||
| joined_args_spec_list_2 = NormalizeArgs(joined_args_spec_list_2); | |||
| MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; | |||
| } | |||
| MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); | |||
| return joined_args_spec_list; | |||
| MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list_2); | |||
| return joined_args_spec_list_2; | |||
| } else { | |||
| trace_.push_back(args_spec_list); | |||
| } | |||