| @@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| MS_LOG(INFO) << "BiasAdd's attr axis is set to default"; | |||
| attr->axis = {1}; | |||
| } else { | |||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||
| attr->axis = CastToInt(prim.GetAttr("axis"), true); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| @@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||
| MS_LOG(WARNING) << "get axis failed"; | |||
| attr->axis = {0}; | |||
| } else { | |||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||
| attr->axis = CastToInt(prim.GetAttr("axis"), true); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| @@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto prim_axis = GetValue<int>(prim.GetAttr("axis")); | |||
| auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); | |||
| attr->axis = prim_axis; | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| @@ -139,21 +139,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT | |||
| } else { | |||
| attr->format = schema::Format::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation"), true); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| auto stride = CastToInt(prim.GetAttr("stride"), true); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| @@ -175,7 +175,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT | |||
| int channel_mutiplier = 1; | |||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||
| channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); | |||
| } | |||
| attr->channelMultiplier = channel_mutiplier; | |||
| @@ -212,25 +212,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive | |||
| } else { | |||
| attr->format = schema::Format::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation"), true); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| auto stride = CastToInt(prim.GetAttr("stride"), true); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||
| attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| @@ -270,7 +270,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| int group = GetValue<int>(groupAttr); | |||
| int group = CastToInt(groupAttr, false).front(); | |||
| if (group > 1) { | |||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||
| } else { | |||
| @@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->group = GetValue<int>(prim.GetAttr("group")); | |||
| attr->group = CastToInt(prim.GetAttr("group"), false).front(); | |||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||
| if (format == "NCHW") { | |||
| attr->format = schema::Format_NCHW; | |||
| @@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation"), true); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| auto stride = CastToInt(prim.GetAttr("stride"), true); | |||
| attr->strideH = stride[0]; | |||
| attr->strideW = stride[1]; | |||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||
| attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| attr->padMode = schema::PadMode_VALID; | |||
| @@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->group = GetValue<int>(prim.GetAttr("group")); | |||
| attr->group = CastToInt(prim.GetAttr("group"), false).front(); | |||
| if (attr->group > 1) { | |||
| this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput; | |||
| } | |||
| @@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation"), true); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| auto stride = CastToInt(prim.GetAttr("stride"), true); | |||
| attr->strideH = stride[0]; | |||
| attr->strideW = stride[1]; | |||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||
| attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid") { | |||
| @@ -132,21 +132,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv | |||
| } else { | |||
| attr->format = schema::Format::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation"), true); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| auto stride = CastToInt(prim.GetAttr("stride"), true); | |||
| attr->strideH = stride[0]; | |||
| attr->strideW = stride[1]; | |||
| @@ -168,7 +168,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv | |||
| int channel_mutiplier = 1; | |||
| if (prim.GetAttr("channel_mutiplier") != nullptr) { | |||
| channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||
| channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); | |||
| } | |||
| attr->channelMultiplier = channel_mutiplier; | |||
| @@ -195,25 +195,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi | |||
| } else { | |||
| attr->format = schema::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list")); | |||
| auto pad_list = CastToInt(prim.GetAttr("pad_list"), true); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation"), true); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| auto stride = CastToInt(prim.GetAttr("stride"), true); | |||
| attr->strideH = stride[0]; | |||
| attr->strideW = stride[1]; | |||
| attr->channelOut = GetValue<int>(prim.GetAttr("out_channel")); | |||
| attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front(); | |||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||
| if (pad_mode == "valid" || pad_mode == "VALID") { | |||
| @@ -248,7 +248,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i | |||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| int group = GetValue<int>(prim.GetAttr("group")); | |||
| int group = CastToInt(prim.GetAttr("group"), false).front(); | |||
| if (group == 1) { | |||
| PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); | |||
| } else if (group > 1) { | |||
| @@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||
| } else { | |||
| attr->format = schema::Format::Format_NUM_OF_FORMAT; | |||
| } | |||
| auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pads")); | |||
| auto pad_list = CastToInt(prim.GetAttr("pads"), true); | |||
| attr->padUp = pad_list[0]; | |||
| attr->padDown = pad_list[1]; | |||
| attr->padLeft = pad_list[2]; | |||
| attr->padRight = pad_list[3]; | |||
| auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation")); | |||
| auto dilation = CastToInt(prim.GetAttr("dilation"), true); | |||
| attr->dilateH = dilation[0]; | |||
| attr->dilateW = dilation[1]; | |||
| if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) { | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true); | |||
| attr->kernelH = kernel_size[0]; | |||
| attr->kernelW = kernel_size[1]; | |||
| } else { | |||
| auto kernel_size = GetValue<int>(prim.GetAttr("kernel_size")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front(); | |||
| attr->kernelH = kernel_size; | |||
| attr->kernelW = kernel_size; | |||
| } | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride")); | |||
| auto stride = CastToInt(prim.GetAttr("stride"), true); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| @@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||
| } else { | |||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||
| } | |||
| auto channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier")); | |||
| auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front(); | |||
| attr->channelMultiplier = channel_multiplier; | |||
| MS_ASSERT(inputs.size() == kAnfPopulaterTwo); | |||
| @@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> | |||
| // use axis instead of dim | |||
| if (inputs[1]->isa<ValueNode>()) { | |||
| auto axis_tensor = inputs[1]->cast<ValueNodePtr>(); | |||
| int axis = GetValue<int>(axis_tensor->value()); | |||
| int axis = CastToInt(axis_tensor->value(), false).front(); | |||
| attr->dim = axis; | |||
| } else { | |||
| MS_LOG(ERROR) << "input axis is not value node."; | |||
| @@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| } | |||
| if (inputs[2]->isa<ValueNode>()) { | |||
| ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>(); | |||
| int axis = GetValue<int>(axis_tensor->value()); | |||
| int axis = CastToInt(axis_tensor->value(), false).front(); | |||
| gather_attr->axis = axis; | |||
| } else { | |||
| MS_LOG(ERROR) << "input axis is not value node."; | |||
| @@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| } | |||
| attr->axis = -1; | |||
| if (prim.GetAttr("axis") != nullptr) { | |||
| attr->axis = GetValue<int>(prim.GetAttr("axis")); | |||
| attr->axis = CastToInt(prim.GetAttr("axis"), false).front(); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| @@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); | |||
| attr->windowH = kernel_size[2]; | |||
| attr->windowW = kernel_size[3]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides")); | |||
| auto stride = CastToInt(prim.GetAttr("strides"), true); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| this->primitive_->value.value = attr; | |||
| @@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> | |||
| attr->padMode = schema::PadMode_NOTSET; | |||
| } | |||
| auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize")); | |||
| auto kernel_size = CastToInt(prim.GetAttr("ksize"), true); | |||
| attr->windowH = kernel_size[2]; | |||
| attr->windowW = kernel_size[3]; | |||
| auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides")); | |||
| auto stride = CastToInt(prim.GetAttr("strides"), true); | |||
| attr->strideH = stride[2]; | |||
| attr->strideW = stride[3]; | |||
| this->primitive_->value.value = attr; | |||
| @@ -180,6 +180,35 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| std::vector<int> CastToInt(const ValuePtr value, bool is_vector) { | |||
| if (value == nullptr) { | |||
| MS_LOG(WARNING) << "valueptr is nullptr."; | |||
| return {}; | |||
| } | |||
| std::vector<int> cur_value; | |||
| if (is_vector) { | |||
| if (!utils::isa<ValueSequeuePtr>(value)) { | |||
| MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar."; | |||
| return {}; | |||
| } | |||
| if (value->cast<ValueSequeuePtr>()->value().front()->type()->type_name() == "Int64Imm") { | |||
| auto origin_value = GetValue<std::vector<int64_t>>(value); | |||
| for (size_t index = 0; index < origin_value.size(); ++index) { | |||
| cur_value.push_back(static_cast<int>(origin_value[index])); | |||
| } | |||
| } else { | |||
| cur_value = GetValue<std::vector<int>>(value); | |||
| } | |||
| } else { | |||
| if (value->type_name() == "Int64Imm") { | |||
| cur_value.push_back(static_cast<int>(GetValue<int64_t>(value))); | |||
| } else { | |||
| cur_value.push_back(GetValue<int>(value)); | |||
| } | |||
| } | |||
| return cur_value; | |||
| } | |||
| void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||
| const float qmin = 0; | |||
| const float qmax = 255; | |||
| @@ -52,6 +52,8 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", | |||
| {"Sigmoid", schema::ActivationType_SIGMOID}, | |||
| {"HSwish", schema::ActivationType_HSWISH}, | |||
| {"HSigmoid", schema::ActivationType_HSIGMOID}}; | |||
| std::vector<int> CastToInt(const ValuePtr value, bool is_vector); | |||
| class PrimitiveC : public mindspore::Primitive { | |||
| public: | |||
| // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). | |||
| @@ -87,7 +87,7 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| attr->axes.emplace_back(elem->value()); | |||
| } | |||
| } else { | |||
| int axes_item = GetValue<int>(value); | |||
| int axes_item = CastToInt(value, false).front(); | |||
| attr->axes.push_back(axes_item); | |||
| } | |||
| } | |||
| @@ -63,7 +63,7 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| attr->shape.emplace_back(static_cast<int>(elem->value())); | |||
| } | |||
| } else { | |||
| int dim = GetValue<int>(val); | |||
| int dim = CastToInt(val, false).front(); | |||
| attr->shape = {dim}; | |||
| } | |||
| } | |||
| @@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| MS_LOG(ERROR) << "wrong resize type"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> targetSize = GetValue<std::vector<int>>(prim.GetAttr("size")); | |||
| std::vector<int> targetSize = CastToInt(prim.GetAttr("size"), true); | |||
| attr->newHeight = targetSize[0]; | |||
| attr->newWidth = targetSize[1]; | |||
| attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners")); | |||
| @@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| MS_LOG(ERROR) << "new primitiveT value failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto prim_axis = GetValue<int>(prim.GetAttr("axis")); | |||
| auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front(); | |||
| attr->axis = prim_axis; | |||
| this->primitive_->value.value = attr; | |||
| if (this->primitive_->value.value == nullptr) { | |||
| @@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in | |||
| MS_LOG(INFO) << "Squeeze's attr xis is set to default"; | |||
| attr->axis = {0}; | |||
| } else { | |||
| attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis")); | |||
| attr->axis = CastToInt(prim.GetAttr("axis"), true); | |||
| } | |||
| this->primitive_->value.value = attr; | |||
| } | |||
| @@ -73,11 +73,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr | |||
| MS_LOG(ERROR) << "new StridedSlice failed"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask")); | |||
| attr->endMask = GetValue<int>(prim.GetAttr("end_mask")); | |||
| attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask")); | |||
| attr->newAxisMask = GetValue<int>(prim.GetAttr("new_axis_mask")); | |||
| attr->shrinkAxisMask = GetValue<int>(prim.GetAttr("shrink_axis_mask")); | |||
| attr->beginMask = CastToInt(prim.GetAttr("begin_mask"), false).front(); | |||
| attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front(); | |||
| attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front(); | |||
| attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front(); | |||
| attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front(); | |||
| auto inputNodeFirst = inputs[kAnfPopulaterOne]; | |||
| std::vector<int> beginVec; | |||
| GetAttrDataFromInput(inputNodeFirst, &beginVec); | |||
| @@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||
| MS_LOG(INFO) << "Tile's attr dims is set to default"; | |||
| attr->dims = {1}; | |||
| } else { | |||
| attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims")); | |||
| attr->dims = CastToInt(prim.GetAttr("dims"), true); | |||
| } | |||
| if (inputs.size() == kAnfPopulaterTwo) { | |||
| auto inputNode = inputs[kAnfPopulaterOne]; | |||
| @@ -75,7 +75,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input | |||
| attr->multiples.emplace_back(elem->value()); | |||
| } | |||
| } else { | |||
| int multiple = GetValue<int>(value); | |||
| int multiple = CastToInt(value, false).front(); | |||
| attr->multiples = {multiple}; | |||
| } | |||
| } | |||
| @@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfN | |||
| std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>(); | |||
| if (inputs[2]->isa<ValueNode>()) { | |||
| ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value(); | |||
| attr->numSegments = GetValue<int>(value); | |||
| attr->numSegments = CastToInt(value, false).front(); | |||
| this->primitive_->value.value = attr.release(); | |||
| } | |||
| } | |||
| @@ -314,7 +314,9 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s | |||
| return RET_ERROR; | |||
| } | |||
| auto input_index_key = | |||
| get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(GetValue<int>(value_node->value())); | |||
| get_item_input_cnode->fullname_with_scope() + "_o:" + | |||
| std::to_string(value_node->value()->type_name() == "Int64Imm" ? GetValue<int64_t>(value_node->value()) | |||
| : GetValue<int>(value_node->value())); | |||
| auto iter = node_id_map_.find(input_index_key); | |||
| if (iter == node_id_map_.end()) { | |||
| #ifdef SUPPORT_TRAIN | |||