| @@ -556,8 +556,6 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { | |||
| ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { | |||
| auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract(); | |||
| auto new_parameter = NewParameter(abstract); | |||
| MS_EXCEPTION_IF_NULL(new_parameter); | |||
| // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter | |||
| if (parameter != nullptr) { | |||
| new_parameter->set_name(parameter->name()); | |||
| @@ -574,7 +572,6 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { | |||
| ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) { | |||
| ParameterPtr new_parameter = add_parameter(); | |||
| new_parameter->set_abstract(abstract); | |||
| MS_EXCEPTION_IF_NULL(new_parameter); | |||
| // create kernel_info form new parameter | |||
| SetKernelInfoForNode(new_parameter); | |||
| AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); | |||
| @@ -130,9 +130,6 @@ class SymbolResolver { | |||
| AnfNodePtr resolved_node() { return resolved_node_; } | |||
| // Resolve result | |||
| py::object result_; | |||
| private: | |||
| // namespace where the symbol locates | |||
| NameSpacePtr namespace_; | |||
| @@ -140,6 +137,8 @@ class SymbolResolver { | |||
| SymbolPtr symbol_; | |||
| // the node that has been resolved | |||
| AnfNodePtr resolved_node_; | |||
| // Resolve result | |||
| py::object result_; | |||
| }; | |||
| using SymbolResolverPtr = std::shared_ptr<SymbolResolver>; | |||
| // Resolve symbol in namespace. | |||
| @@ -51,10 +51,9 @@ class ResourceBase { | |||
| bool HasResult(const std::string &key) const { return results_.count(key) != 0; } | |||
| std::unordered_map<std::string, Any> results_; | |||
| protected: | |||
| FuncGraphManagerPtr manager_; | |||
| std::unordered_map<std::string, Any> results_; | |||
| }; | |||
| using ResourceBasePtr = std::shared_ptr<pipeline::ResourceBase>; | |||
| @@ -22,7 +22,6 @@ namespace mindspore { | |||
| namespace abstract { | |||
| class Evaluator; | |||
| class AnalysisEngine; | |||
| AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) { | |||
| if (func_list.size() == 1) { | |||
| return func_list[0]; | |||
| @@ -102,7 +101,7 @@ AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { | |||
| return std::make_shared<AbstractFuncUnion>(this_func, other); | |||
| } | |||
| auto other_union = dyn_cast<AbstractFuncUnion>(other); | |||
| MS_EXCEPTION_IF_NULL(other); | |||
| MS_EXCEPTION_IF_NULL(other_union); | |||
| if (other_union->IsSuperSet(this_func)) { | |||
| return other; | |||
| } | |||
| @@ -110,7 +109,7 @@ AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { | |||
| } | |||
| void AbstractFuncUnion::Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const { | |||
| for (AbstractFuncAtomPtr poss : func_list_) { | |||
| for (const AbstractFuncAtomPtr &poss : func_list_) { | |||
| visit_func(poss); | |||
| } | |||
| } | |||
| @@ -123,15 +122,12 @@ bool AbstractFuncUnion::operator==(const AbstractFunction &other) const { | |||
| if (func_list_.size() != other_union->func_list_.size()) { | |||
| return false; | |||
| } | |||
| if (func_list_ == other_union->func_list_) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return func_list_ == other_union->func_list_; | |||
| } | |||
| std::size_t AbstractFuncUnion::hash() const { | |||
| std::size_t hash_sum = 0; | |||
| for (auto f : func_list_) { | |||
| for (const auto &f : func_list_) { | |||
| MS_EXCEPTION_IF_NULL(f); | |||
| hash_sum = hash_combine(hash_sum, f->hash()); | |||
| } | |||
| @@ -144,10 +140,7 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| } | |||
| auto other_prim = static_cast<const PrimitiveAbstractClosure *>(&other); | |||
| MS_EXCEPTION_IF_NULL(prim_); | |||
| if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()); | |||
| } | |||
| std::size_t PrimitiveAbstractClosure::hash() const { | |||
| @@ -165,11 +158,8 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| return false; | |||
| } | |||
| auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other); | |||
| if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ && | |||
| tracking_id() == other_fg->tracking_id()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ && | |||
| tracking_id() == other_fg->tracking_id(); | |||
| } | |||
| std::size_t FuncGraphAbstractClosure::hash() const { | |||
| @@ -195,10 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con | |||
| return false; | |||
| } | |||
| auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other); | |||
| if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id(); | |||
| } | |||
| std::size_t MetaFuncGraphAbstractClosure::hash() const { | |||
| @@ -226,10 +213,7 @@ bool PartialAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (args_spec_list_.size() != other_partial->args_spec_list_.size()) { | |||
| return false; | |||
| } | |||
| if (args_spec_list_ == other_partial->args_spec_list_) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return args_spec_list_ == other_partial->args_spec_list_; | |||
| } | |||
| std::size_t PartialAbstractClosure::hash() const { | |||
| @@ -255,10 +239,7 @@ bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) cons | |||
| return false; | |||
| } | |||
| auto other_transformed = static_cast<const JTransformedAbstractClosure *>(&other); | |||
| if (fn_ == other_transformed->fn_) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return fn_ == other_transformed->fn_; | |||
| } | |||
| std::size_t JTransformedAbstractClosure::hash() const { | |||
| @@ -278,10 +259,7 @@ bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) { | |||
| return false; | |||
| } | |||
| if (args_spec_list_ == other_virtual->args_spec_list_) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return args_spec_list_ == other_virtual->args_spec_list_; | |||
| } | |||
| std::size_t VirtualAbstractClosure::hash() const { | |||
| @@ -319,10 +297,7 @@ bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) co | |||
| if (args_spec_list_.size() != other_typed->args_spec_list_.size()) { | |||
| return false; | |||
| } | |||
| if (args_spec_list_ == other_typed->args_spec_list_) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return args_spec_list_ == other_typed->args_spec_list_; | |||
| } | |||
| std::size_t TypedPrimitiveAbstractClosure::hash() const { | |||
| @@ -346,10 +321,7 @@ std::string TypedPrimitiveAbstractClosure::ToString() const { | |||
| } | |||
| bool DummyAbstractClosure::operator==(const AbstractFunction &other) const { | |||
| if (!other.isa<DummyAbstractClosure>()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| return !other.isa<DummyAbstractClosure>(); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -238,6 +238,7 @@ std::string AbstractError::ToString() const { | |||
| std::ostringstream buffer; | |||
| auto value_track = GetValueTrack(); | |||
| MS_EXCEPTION_IF_NULL(value_track); | |||
| MS_EXCEPTION_IF_NULL(node_); | |||
| buffer << type_name() << "(" | |||
| << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")"; | |||
| return buffer.str(); | |||
| @@ -594,6 +595,8 @@ bool AbstractTensor::equal_to(const AbstractTensor &other) const { | |||
| if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) { | |||
| is_value_equal = true; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(element_); | |||
| MS_EXCEPTION_IF_NULL(other.element_); | |||
| return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal; | |||
| } | |||
| @@ -912,6 +915,7 @@ AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) { | |||
| if (other_jtagged == nullptr) { | |||
| AbstractTypeJoinLogging(shared_from_base<AbstractBase>(), other); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(element_); | |||
| auto joined_elem = element_->Join(other_jtagged->element_); | |||
| return std::make_shared<AbstractJTagged>(joined_elem); | |||
| } | |||
| @@ -85,7 +85,7 @@ AnalysisContextPtr AnalysisContext::FindOwnOrParentContext(const FuncGraphPtr &f | |||
| oss << "nullptr"; | |||
| } | |||
| oss << " extant context list: {"; | |||
| for (auto iter : extant_context_cache_) { | |||
| for (const auto &iter : extant_context_cache_) { | |||
| if (iter.first == nullptr) { | |||
| oss << " [graph: nullptr"; | |||
| } else { | |||
| @@ -108,10 +108,7 @@ AnalysisContextPtr AnalysisContext::DummyContext() { | |||
| } | |||
| bool AnalysisContext::IsDummyContext() { | |||
| if (parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty(); | |||
| } | |||
| const AnalysisContextPtr kDummyAnalysisContext = | |||
| @@ -75,6 +75,7 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti | |||
| [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); }); | |||
| ShapeVector res = BroadcastShape(shp_x, shp_y); | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[1]); | |||
| if (res.empty()) { | |||
| MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," | |||
| << args_spec_list[1]->ToString(); | |||
| @@ -115,16 +116,17 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| (void)CheckDtypeSame(op_name, tensor_base, tensor); | |||
| (void)CheckShapeSame(op_name, tensor_base, tensor); | |||
| } | |||
| auto element = tensor_base->element(); | |||
| MS_EXCEPTION_IF_NULL(element); | |||
| primitive->set_attr("N", MakeValue(SizeToLong(tuple_len))); | |||
| primitive->set_attr("T", tensor_base->element()->BuildType()); | |||
| primitive->set_attr("T", element->BuildType()); | |||
| AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden()); | |||
| MS_EXCEPTION_IF_NULL(ret); | |||
| auto ret_shape_ptr = ret->shape(); | |||
| MS_EXCEPTION_IF_NULL(ret_shape_ptr); | |||
| auto ret_shape = ret->shape()->shape(); | |||
| (void)ret_shape.insert(ret_shape.begin() + axis_value, tuple_len); | |||
| auto ret_shape = ret_shape_ptr->shape(); | |||
| (void)ret_shape.insert(ret_shape.begin() + axis_value, SizeToLong(tuple_len)); | |||
| ret->set_shape(std::make_shared<Shape>(ret_shape)); | |||
| return ret; | |||
| } | |||
| @@ -137,7 +139,8 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto shape = input->shape(); | |||
| if (shape->shape().size() != 1) { | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| if (shape->shape().empty()) { | |||
| MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1."; | |||
| } | |||
| ShapeVector ids_shape = {Shape::SHP_ANY}; | |||
| @@ -205,6 +208,10 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt | |||
| CheckArgsSize(op_name + " dout", dout->elements(), size_expected); | |||
| auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0); | |||
| auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1); | |||
| auto ids_shape = ids->shape(); | |||
| auto ids_idx_shape = ids_idx->shape(); | |||
| MS_EXCEPTION_IF_NULL(ids_shape); | |||
| MS_EXCEPTION_IF_NULL(ids_idx_shape); | |||
| if (ids->shape()->shape().size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "' input must be 1."; | |||
| } | |||
| @@ -278,7 +285,6 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri | |||
| const size_t size_expected = 3; | |||
| CheckArgsSize(op_name, args_spec_list, size_expected); | |||
| auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(x); | |||
| MS_EXCEPTION_IF_NULL(x->shape()); | |||
| auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(segment_ids); | |||
| @@ -426,12 +432,10 @@ AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitiveP | |||
| const size_t size_expected = 5; | |||
| CheckArgsSize(op_name, args_spec_list, size_expected); | |||
| auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(hash_map); | |||
| MS_EXCEPTION_IF_NULL(hash_map->shape()); | |||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| auto indices_shp = indices->shape(); | |||
| MS_EXCEPTION_IF_NULL(indices); | |||
| MS_EXCEPTION_IF_NULL(indices_shp); | |||
| ShapeVector shape; | |||
| @@ -466,12 +470,10 @@ AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const Primiti | |||
| CheckArgsSize(op_name, args_spec_list, size_expected); | |||
| auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto cache_table_shp = cache_table->shape(); | |||
| MS_EXCEPTION_IF_NULL(cache_table); | |||
| MS_EXCEPTION_IF_NULL(cache_table_shp); | |||
| auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| auto swap_cache_idx_shp = swap_cache_idx->shape(); | |||
| MS_EXCEPTION_IF_NULL(swap_cache_idx); | |||
| MS_EXCEPTION_IF_NULL(swap_cache_idx_shp); | |||
| auto cache_table_shape = cache_table_shp->shape(); | |||
| @@ -501,12 +503,6 @@ AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(indices); | |||
| MS_EXCEPTION_IF_NULL(indices->shape()); | |||
| ShapeVector shape; | |||
| shape.emplace_back(1); | |||
| @@ -520,7 +516,6 @@ AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const Primitive | |||
| const std::string op_name = primitive->name(); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto input_x_shp = input_x->shape(); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x_shp); | |||
| ShapeVector shape; | |||
| @@ -648,6 +643,7 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv | |||
| MS_LOG(INFO) << "InferImplDynamicAssign " << args_spec_list[0]; | |||
| auto type = args_spec_list[0]->BuildType(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| if (type->type_id() == kObjectTypeRefKey) { | |||
| return args_spec_list[1]->Broaden(); | |||
| } else { | |||
| @@ -669,12 +665,10 @@ AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const Primit | |||
| const std::string op_name = primitive->name(); | |||
| auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto params_shp = params->shape(); | |||
| MS_EXCEPTION_IF_NULL(params); | |||
| MS_EXCEPTION_IF_NULL(params_shp); | |||
| auto params_shape = params_shp->shape(); | |||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| auto indices_shp = indices->shape(); | |||
| MS_EXCEPTION_IF_NULL(indices); | |||
| MS_EXCEPTION_IF_NULL(indices_shp); | |||
| auto indices_shape = indices_shp->shape(); | |||
| auto indices_max_shape = indices_shp->max_shape(); | |||
| @@ -726,8 +720,8 @@ AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const Primitive | |||
| const std::string &op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input->shape()); | |||
| auto shape = input->shape()->shape(); | |||
| bool has_dyn_shape = std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; }); | |||
| ShapeVector tensor_shp({static_cast<int64_t>(shape.size())}); | |||
| if (has_dyn_shape) { | |||
| @@ -750,6 +744,7 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr | |||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto input_shp = input->shape()->shape(); | |||
| ValuePtr perm = primitive->GetAttr("perm"); | |||
| MS_EXCEPTION_IF_NULL(perm); | |||
| auto perm_val = perm->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(perm_val); | |||
| auto perm_val_data = perm_val->value(); | |||
| @@ -788,6 +783,7 @@ AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| x_min_shape = x_shape; | |||
| } | |||
| ValuePtr sh = primitive->GetAttr("shape"); | |||
| MS_EXCEPTION_IF_NULL(sh); | |||
| auto reshape_value_tuple = sh->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(reshape_value_tuple); | |||
| auto reshape_tuple = reshape_value_tuple->value(); | |||
| @@ -866,7 +862,7 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| ValuePtr axis = primitive->GetAttr("axis"); | |||
| int64_t axis_value = CheckAxis(op_name, axis, -(rank + 1), rank); | |||
| uint64_t axis_value_pos = LongToUlong(GetPositiveAxis(axis_value, LongToSize(rank))); | |||
| int64_t output_num_value = primitive->GetAttr("output_num")->cast<Int64ImmPtr>()->value(); | |||
| int64_t output_num_value = GetValue<int64_t>(primitive->GetAttr("output_num")); | |||
| if ((x_shape[axis_value_pos] != Shape::SHP_ANY) && (x_shape[axis_value_pos] % output_num_value != 0)) { | |||
| MS_LOG(EXCEPTION) << "x_shape[" << axis_value_pos << "] = " << x_shape[axis_value_pos] | |||
| << " must be divisible by output_num = " << output_num_value; | |||
| @@ -91,6 +91,7 @@ int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) { | |||
| void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis, | |||
| bool keep_dims_value) { | |||
| MS_EXCEPTION_IF_NULL(axis); | |||
| if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) { | |||
| auto axis_ptr_list = | |||
| axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value(); | |||
| @@ -57,12 +57,12 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & | |||
| int64_t h_input = input_shape->shape()[H_INDEX]; | |||
| int64_t w_input = input_shape->shape()[W_INDEX]; | |||
| int64_t window = primitive->GetAttr("window")->cast<Int64ImmPtr>()->value(); | |||
| int64_t stride = primitive->GetAttr("stride")->cast<Int64ImmPtr>()->value(); | |||
| int64_t padding = primitive->GetAttr("pad")->cast<Int64ImmPtr>()->value(); | |||
| int64_t nan_opt = primitive->GetAttr("nan_opt")->cast<Int64ImmPtr>()->value(); | |||
| int64_t data_mode = primitive->GetAttr("data_mode")->cast<Int64ImmPtr>()->value(); | |||
| int64_t ceil_mode = primitive->GetAttr("ceil_mode")->cast<Int64ImmPtr>()->value(); | |||
| int64_t window = GetValue<int64_t>(primitive->GetAttr("window")); | |||
| int64_t stride = GetValue<int64_t>(primitive->GetAttr("stride")); | |||
| int64_t padding = GetValue<int64_t>(primitive->GetAttr("pad")); | |||
| int64_t nan_opt = GetValue<int64_t>(primitive->GetAttr("nan_opt")); | |||
| int64_t data_mode = GetValue<int64_t>(primitive->GetAttr("data_mode")); | |||
| int64_t ceil_mode = GetValue<int64_t>(primitive->GetAttr("ceil_mode")); | |||
| if (stride <= 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should greater then 0"; | |||
| @@ -124,32 +124,6 @@ AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitiveP | |||
| return ret; | |||
| } | |||
| void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { | |||
| // check dimension, x > 1, others equal 1 | |||
| const std::string op_name = primitive->name(); | |||
| for (std::size_t i = 0; i < args_spec_list.size(); ++i) { | |||
| AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, i); | |||
| ShapePtr arg_shape = dyn_cast<Shape>(arg->GetShapeTrack()); | |||
| if (arg_shape == nullptr) { | |||
| MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString(); | |||
| } | |||
| if (i == 0) { | |||
| if (arg_shape->shape().size() < 2) { | |||
| MS_LOG(EXCEPTION) << op_name << " shape of args[" << i | |||
| << "] should be TensorShape with dimension greater than 1, but shape: " | |||
| << arg_shape->ToString(); | |||
| } | |||
| continue; | |||
| } | |||
| if (arg_shape->shape().size() != 1) { | |||
| MS_LOG(EXCEPTION) << op_name << " shape of args[" << i | |||
| << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString(); | |||
| } | |||
| } | |||
| } | |||
| AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: five tensors(x, gamma, beta, mean, variance). | |||
| @@ -100,6 +100,7 @@ AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) { | |||
| ValuePtr name_value = prim->GetAttr("tag"); | |||
| MS_EXCEPTION_IF_NULL(name_value); | |||
| auto name = name_value->cast<StringImmPtr>(); | |||
| if (name == nullptr) { | |||
| MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << "."; | |||
| @@ -132,7 +133,9 @@ AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr | |||
| if (type->type_id() != kObjectTypeRef) { | |||
| MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString(); | |||
| } | |||
| return args_spec_list[0]->cast<AbstractRefPtr>()->ref(); | |||
| auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>(); | |||
| MS_EXCEPTION_IF_NULL(abs_ref); | |||
| return abs_ref->ref(); | |||
| } | |||
| AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &, | |||
| @@ -146,7 +149,9 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP | |||
| if (type->type_id() != kObjectTypeRef) { | |||
| return args_spec_list[0]; | |||
| } | |||
| return args_spec_list[0]->cast<AbstractRefPtr>()->ref(); | |||
| auto abs_ref = args_spec_list[0]->cast<AbstractRefPtr>(); | |||
| MS_EXCEPTION_IF_NULL(abs_ref); | |||
| return abs_ref->ref(); | |||
| } | |||
| AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -185,6 +190,7 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP | |||
| if (args_spec_list.empty()) { | |||
| MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(args_spec_list[0]); | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| @@ -213,14 +219,16 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv | |||
| << values_shp[0] << ", but got " << indices_shp[0]; | |||
| } | |||
| for (auto elem_type : dense_shape->ElementsType()) { | |||
| for (const auto &elem_type : dense_shape->ElementsType()) { | |||
| if (!elem_type->isa<Int>()) { | |||
| MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString(); | |||
| } | |||
| } | |||
| auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>(); | |||
| auto dense_shape_value = dense_shape->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(dense_shape_value); | |||
| auto shp = dense_shape_value->value(); | |||
| auto dense_shape_valuetuple = dense_shape_value->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(dense_shape_valuetuple); | |||
| auto shp = dense_shape_valuetuple->value(); | |||
| ShapeVector dense_shape_vec; | |||
| (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), | |||
| [](const ValuePtr &e) -> int64_t { | |||
| @@ -229,7 +237,7 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv | |||
| }); | |||
| if (dense_shape_vec.size() != values_shp.size()) { | |||
| MS_EXCEPTION(TypeError) << "The size of dense_shape must be the same with the dimension of values " | |||
| << values_shp.size() << ", but got " << dense_shape_value->size(); | |||
| << values_shp.size() << ", but got " << dense_shape_valuetuple->size(); | |||
| } | |||
| for (size_t i = 0; i < dense_shape_vec.size(); i++) { | |||
| if (dense_shape_vec[i] < 0) { | |||
| @@ -321,7 +329,7 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi | |||
| << values_shp[0] << ", but got " << indices_shp[0]; | |||
| } | |||
| for (auto elem_type : dense_shape->ElementsType()) { | |||
| for (const auto &elem_type : dense_shape->ElementsType()) { | |||
| if (!elem_type->isa<Int>()) { | |||
| MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString(); | |||
| } | |||
| @@ -502,9 +510,8 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| auto attr = primitive->GetAttr("dst_type"); | |||
| if (attr == nullptr) { | |||
| auto input_dtype = args_spec_list[1]; | |||
| MS_EXCEPTION_IF_NULL(input_dtype); | |||
| attr = input_dtype->BuildValue(); | |||
| auto type_abs = CheckArg<AbstractType>(op_name, args_spec_list, 1); | |||
| attr = type_abs->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(attr); | |||
| primitive->set_attr("dst_type", attr); | |||
| } | |||
| @@ -74,7 +74,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP | |||
| abstract::CheckArgsSize(op_name, args_spec_list, kSwitchLayerInputNum); | |||
| auto index = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto &input_shape = index->shape()->shape(); | |||
| if (input_shape.size() != 0) { | |||
| if (!input_shape.empty()) { | |||
| MS_EXCEPTION(ValueError) << op_name << " index must be a 0 dimension tensor, but got a " << input_shape.size() | |||
| << " dimension tensor"; | |||
| } | |||
| @@ -86,7 +86,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP | |||
| AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | |||
| AbstractBasePtrList branches = branches_abs->elements(); | |||
| const size_t maximum_layer_num = 1000; | |||
| if (branches.size() < 1 || branches.size() > maximum_layer_num) { | |||
| if (branches.empty() || branches.size() > maximum_layer_num) { | |||
| MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got " | |||
| << branches.size() << " branches."; | |||
| } | |||
| @@ -70,6 +70,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0); | |||
| ValuePtr keyPtr = key->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(keyPtr); | |||
| if (!keyPtr->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); | |||
| } | |||
| @@ -86,6 +87,7 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive | |||
| AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(key_value); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| } | |||
| @@ -110,6 +112,7 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr | |||
| slice_args.push_back(args_spec_list[index]); | |||
| } else if (args_spec_list[index]->isa<AbstractScalar>()) { | |||
| ValuePtr scalar_value = args_spec_list[index]->cast<AbstractScalarPtr>()->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(scalar_value); | |||
| if (scalar_value->isa<IntergerImm>()) { | |||
| slice_args.push_back(args_spec_list[index]); | |||
| } else if (scalar_value->isa<BoolImm>()) { | |||
| @@ -122,8 +125,9 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr | |||
| } else if (args_spec_list[index]->isa<AbstractTensor>()) { | |||
| auto arg = args_spec_list[index]->cast<AbstractTensorPtr>(); | |||
| TypePtr tensor_dtype = arg->element()->BuildType(); | |||
| auto value = arg->BuildValue()->cast<tensor::TensorPtr>(); | |||
| auto build_value = arg->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(build_value); | |||
| auto value = build_value->cast<tensor::TensorPtr>(); | |||
| if (value == nullptr) { | |||
| MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must be a const tensor."; | |||
| } | |||
| @@ -159,6 +163,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr index_value = index->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(index_value); | |||
| if (!index_value->isa<Int64Imm>()) { | |||
| // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element | |||
| // and continue | |||
| @@ -191,6 +196,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra | |||
| AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr index_value = index->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(index_value); | |||
| if (!index_value->isa<Int64Imm>()) { | |||
| MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " | |||
| << index_value->ToString(); | |||
| @@ -237,6 +243,7 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(key_value); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| } | |||
| @@ -259,6 +266,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP | |||
| AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| ValuePtr key_value = key->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(key_value); | |||
| if (!key_value->isa<StringImm>()) { | |||
| MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); | |||
| } | |||
| @@ -68,7 +68,7 @@ ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, c | |||
| continue; | |||
| } | |||
| if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) { | |||
| if (shape1->min_shape().empty() || shape1->max_shape().empty()) { | |||
| if (shape1->min_shape().size() <= i || shape1->max_shape().size() <= i) { | |||
| MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString() | |||
| << " has dynamic shape, but does not have min/max shape info."; | |||
| } | |||
| @@ -77,7 +77,7 @@ ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, c | |||
| continue; | |||
| } | |||
| if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) { | |||
| if (shape2->min_shape().empty() || shape2->max_shape().empty()) { | |||
| if (shape2->min_shape().size() <= i || shape2->max_shape().size() <= i) { | |||
| MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString() | |||
| << " has dynamic shape, but does not have min/max shape info."; | |||
| } | |||
| @@ -86,11 +86,11 @@ ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, c | |||
| continue; | |||
| } | |||
| // both shapes contains dynamic shape | |||
| if (shape1->min_shape().empty() || shape1->max_shape().empty()) { | |||
| if (shape1->min_shape().size() <= i || shape1->max_shape().size() <= i) { | |||
| MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString() | |||
| << " has dynamic shape, but does not have min/max shape info."; | |||
| } | |||
| if (shape2->min_shape().empty() || shape2->max_shape().empty()) { | |||
| if (shape2->min_shape().size() <= i || shape2->max_shape().size() <= i) { | |||
| MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString() | |||
| << " has dynamic shape, but does not have min/max shape info."; | |||
| } | |||
| @@ -109,10 +109,10 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { | |||
| // lengths of two shapes are not same, join failed | |||
| if (shape1->shape().size() != shape2->shape().size()) { | |||
| // special case: shape(1), shape() -> shape(1) | |||
| if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().size() == 0) { | |||
| if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().empty()) { | |||
| return shape1; | |||
| } | |||
| if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().size() == 0) { | |||
| if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().empty()) { | |||
| return shape2; | |||
| } | |||
| return nullptr; | |||
| @@ -138,7 +138,7 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { | |||
| } | |||
| AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) { | |||
| if (args_spec_list.size() < 1) { | |||
| if (args_spec_list.empty()) { | |||
| MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size() | |||
| << "."; | |||
| } | |||
| @@ -205,7 +205,7 @@ bool CheckType(const TypePtr &expected_type, const TypePtr &x) { | |||
| TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) { | |||
| MS_EXCEPTION_IF_NULL(predicate); | |||
| for (auto arg_type : args_type_list) { | |||
| for (const auto &arg_type : args_type_list) { | |||
| MS_EXCEPTION_IF_NULL(arg_type); | |||
| if (!CheckType(predicate, arg_type)) { | |||
| MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString(); | |||
| @@ -294,19 +294,6 @@ ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) { | |||
| return shp; | |||
| } | |||
| ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, | |||
| const AbstractTensorPtr &tensor_y) { | |||
| mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape(); | |||
| mindspore::abstract::ShapePtr tensor_y_shape = tensor_y->shape(); | |||
| // if is the same shape ,just return the x_shape | |||
| if (*tensor_x_shape == *tensor_y_shape) { | |||
| return tensor_x_shape; | |||
| } | |||
| auto x_shape = tensor_x_shape->shape(); | |||
| auto y_shape = tensor_y_shape->shape(); | |||
| return std::make_shared<Shape>(RealBroadcast(op, x_shape, y_shape)); | |||
| } | |||
| size_t TypeIdSize(const TypeId data_type) { | |||
| const size_t unsupported_type_error = 0; | |||
| auto iter = type_map.find(data_type); | |||