Merge pull request !4678 from yeyunpeng2020/master_cops_3
@@ -61,18 +61,17 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto argmax_prim = this->primitive->value_as_ArgMax();
   std::vector<int> output_shape(input->shape());
   auto input_shape_size = input->shape().size();
-  int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis();
+  int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
   if (axis >= input_shape_size || axis < 0) {
-    MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size;
+    MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
     return RET_PARAM_INVALID;
   }
-  if (argmax_prim->topK() == 1 && !argmax_prim->keepDims()) {
+  if (GetTopK() == 1 && !GetKeepDims()) {
     output_shape.erase(output_shape.begin() + axis);
   } else {
-    output_shape[axis] = argmax_prim->topK();
+    output_shape[axis] = GetTopK();
   }
   output->set_shape(output_shape);
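Note: the shape rule this hunk implements is the usual negative-axis normalization plus a topK-aware output shape. A minimal standalone sketch (the helper name and `main` harness are illustrative, not the lite API):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical standalone helper mirroring the hunk above: normalize a
// possibly negative axis against the input rank, then either drop the axis
// (topK == 1 && !keepDims) or resize it to topK.
std::vector<int> ArgMaxOutputShape(const std::vector<int> &input_shape,
                                   int axis, int top_k, bool keep_dims) {
  int rank = static_cast<int>(input_shape.size());
  if (axis < 0) axis += rank;      // e.g. axis = -1 on a 3-D tensor -> 2
  assert(axis >= 0 && axis < rank);
  std::vector<int> out(input_shape);
  if (top_k == 1 && !keep_dims) {
    out.erase(out.begin() + axis); // reduce the argmax axis away
  } else {
    out[axis] = top_k;             // keep the axis, sized to the k results
  }
  return out;
}

int main() {
  auto shape = ArgMaxOutputShape({2, 3, 4}, -1, 1, false);
  for (int d : shape) printf("%d ", d);  // prints: 2 3
  printf("\n");
  return 0;
}
```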
@@ -46,7 +46,7 @@ void ArgMin::SetKeepDims(bool keep_dims) {}
 void ArgMin::SetAxisType(int axis_type) {}
 #endif
-int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
+int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
   MS_ASSERT(this->primitive != nullptr);
   auto input = inputs_.front();
   MS_ASSERT(input != nullptr);
@@ -60,18 +60,17 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto argmin_prim = this->primitive->value_as_ArgMin();
   auto input_shape_size = input->shape().size();
-  int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis();
+  int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
   if (axis >= input_shape_size || axis < 0) {
-    MS_LOG(ERROR) << "Invalid axis " << argmin_prim->axis() << ", input shape size: " << input_shape_size;
+    MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
     return RET_PARAM_INVALID;
   }
   std::vector<int> output_shape(input->shape());
-  if (argmin_prim->topK() == 1 && !argmin_prim->keepDims()) {
+  if (GetTopK() == 1 && !GetKeepDims()) {
     output_shape.erase(output_shape.begin() + axis);
   } else {
-    output_shape[axis] = argmin_prim->topK();
+    output_shape[axis] = GetTopK();
   }
   output->set_shape(output_shape);
@@ -39,11 +39,10 @@ constexpr int kBroadcastToInputNum = 1;
 constexpr int kBroadcastToOutputNum = 1;
 } // namespace
-int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
-  MS_ASSERT(this->primitive != nullptr);
+int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
   if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
     MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
-    return 1;
+    return RET_PARAM_INVALID;
   }
   auto input = inputs.at(0);
   outputs[0]->SetFormat(input->GetFormat());
@@ -51,27 +50,26 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
-                                 this->primitive->value_as_BroadcastTo()->dst_shape()->end());
+  std::vector<int32_t> dst_shape(GetDstShape().begin(), GetDstShape().end());
   auto input_shape = input->shape();
   std::vector<int> shape(dst_shape.size());
   int input_shape_index = input_shape.size() - 1;
   if (input_shape.size() > dst_shape.size()) {
     MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size "
                   << dst_shape.size() << "!";
-    return 1;
+    return RET_PARAM_INVALID;
   }
   for (int i = dst_shape.size() - 1; i >= 0; --i) {
     if (dst_shape[i] < 0) {
       MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!";
-      return 1;
+      return RET_PARAM_INVALID;
     }
     if (input_shape_index >= 0) {
       auto dim = input_shape[input_shape_index];
       if (dim != dst_shape[i] && dim != 1) {
         MS_LOG(ERROR) << "Invalid broadcast shape!";
-        return 1;
+        return RET_PARAM_INVALID;
       }
     }
     shape[i] = dst_shape[i];
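Note: the broadcast check above walks `dst_shape` right to left. A minimal sketch of that rule as a standalone helper (names ours, not the lite code):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the hunk above: walking dst_shape from the
// last dimension, each overlapping input dimension must equal the target or
// be 1; leading target dimensions with no input counterpart are taken as-is.
bool BroadcastShape(const std::vector<int> &input_shape,
                    const std::vector<int> &dst_shape,
                    std::vector<int> *out_shape) {
  if (input_shape.size() > dst_shape.size()) return false;
  out_shape->assign(dst_shape.size(), 0);
  int input_index = static_cast<int>(input_shape.size()) - 1;
  for (int i = static_cast<int>(dst_shape.size()) - 1; i >= 0; --i) {
    if (dst_shape[i] < 0) return false;  // target dims must be positive
    if (input_index >= 0) {
      int dim = input_shape[input_index--];
      if (dim != dst_shape[i] && dim != 1) return false;  // not broadcastable
    }
    (*out_shape)[i] = dst_shape[i];
  }
  return true;
}

int main() {
  std::vector<int> out;
  // {3, 1} broadcast to {2, 3, 4}: trailing dims pair up as (1,4) and (3,3).
  bool ok = BroadcastShape({3, 1}, {2, 3, 4}, &out);
  printf("%s\n", ok ? "broadcastable" : "invalid");  // broadcastable
  return 0;
}
```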
@@ -45,14 +45,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
     return RET_INPUT_TENSOR_ERROR;
   }
   output->SetFormat(input->GetFormat());
-  auto cast_prim = this->primitive->value_as_Cast();
-  MS_ASSERT(cast_prim != nullptr);
-  output->set_data_type(static_cast<TypeId>(cast_prim->dstT()));
+  output->set_data_type(static_cast<TypeId>(GetDstT()));
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  if (input->data_type() != cast_prim->srcT()) {
+  if (input->data_type() != GetSrcT()) {
     MS_LOG(ERROR) << "input dataType is error";
     return RET_INPUT_TENSOR_ERROR;
   }
@@ -55,10 +55,10 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto concat_prim = this->primitive->value_as_Concat();
-  MS_ASSERT(concat_prim != nullptr);
   auto input0_shape = inputs_.at(0)->shape();
-  int axis = concat_prim->axis() < 0 ? concat_prim->axis() + input0_shape.size() : concat_prim->axis();
+  int axis = GetAxis() < 0 ? GetAxis() + input0_shape.size() : GetAxis();
   if (axis < 0 || axis >= input0_shape.size()) {
     MS_LOG(ERROR) << "Invalid axis: " << axis;
     return RET_PARAM_INVALID;
@@ -41,7 +41,6 @@ constexpr int kCropOutputNum = 1;
 constexpr int kCropInputNum = 2;
 } // namespace
 int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
-  MS_ASSERT(this->primitive != nullptr);
   if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) {
     MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
     return RET_PARAM_INVALID;
@@ -139,7 +139,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
   } else {
     MS_LOG(ERROR) << "unsupported pad mode for deconv";
   }
   std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
   output->set_shape(out_shape);
   return 0;
@@ -154,7 +154,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
   out_shape.at(2) = output_w;
   if (GetChannelMultiplier() * input_channel != weight->shape()[0]) {
     MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
-    return 1;
+    return RET_ERROR;
   }
   out_shape.at(3) = weight->shape()[0] * weight->shape()[3];  // in_channel * out_channel
@@ -42,13 +42,13 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
   MS_ASSERT(this->primitive != nullptr);
   if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) {
     MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
-    return 1;
+    return RET_PARAM_INVALID;
   }
   auto input = inputs.at(0);
   if (input->GetFormat() != schema::Format_NHWC) {
     MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
-    return 1;
+    return RET_FORMAT_ERR;
   }
   outputs[0]->set_data_type(input->data_type());
   outputs[0]->SetFormat(input->GetFormat());
@@ -58,14 +58,14 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
   auto input_shape = input->shape();
   if (input_shape.size() != kDimension_4d) {
     MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
-    return 1;
+    return RET_PARAM_INVALID;
   }
   int32_t block_size = GetBlockSize();
   if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) {
     MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size("
                   << block_size << ") * block_size)!";
-    return 1;
+    return RET_PARAM_INVALID;
   }
   std::vector<int32_t> output_shape(input_shape.size());
   output_shape[NHWC_N] = input_shape[NHWC_N];
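Note: the divisibility check above guards the standard NHWC DepthToSpace shape rule. A minimal sketch (hypothetical helper, not the lite class):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical helper showing the shape rule behind the check above:
// channels must divide by block_size^2; spatial dims grow by block_size
// while channels shrink by the same factor squared.
std::vector<int> DepthToSpaceShape(const std::vector<int> &nhwc, int block) {
  assert(nhwc.size() == 4);
  assert(nhwc[3] != 0 && nhwc[3] % (block * block) == 0);
  return {nhwc[0],                     // N unchanged
          nhwc[1] * block,             // H * block_size
          nhwc[2] * block,             // W * block_size
          nhwc[3] / (block * block)};  // C / block_size^2
}

int main() {
  auto s = DepthToSpaceShape({1, 2, 2, 8}, 2);
  printf("%d %d %d %d\n", s[0], s[1], s[2], s[3]);  // 1 4 4 2
  return 0;
}
```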
@@ -47,8 +47,7 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto expand_dims_prim = this->primitive->value_as_ExpandDims();
-  int dim = expand_dims_prim->dim();
+  int dim = GetDim();
   if (dim < 0) {
     dim += input->shape().size() + 1;
   }
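Note: unlike the axis normalization elsewhere in this PR, ExpandDims adds 1 because the insertion index is measured against the output rank. A tiny sketch (helper name ours):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper for the rule above: a negative insertion index counts
// from the end of the *output* shape, hence the "+ 1" when normalizing.
std::vector<int> ExpandDimsShape(std::vector<int> shape, int dim) {
  if (dim < 0) dim += static_cast<int>(shape.size()) + 1;  // -1 -> append
  shape.insert(shape.begin() + dim, 1);
  return shape;
}

int main() {
  auto s = ExpandDimsShape({2, 3}, -1);
  printf("%d %d %d\n", s[0], s[1], s[2]);  // 2 3 1
  return 0;
}
```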
@@ -58,10 +58,10 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto gather_prim = this->primitive->value_as_Gather();
-  MS_ASSERT(gather_prim != nullptr);
-  int axis = gather_prim->axis();
-  int batch_dims = gather_prim->batchDims();
+  int axis = GetAxis();
+  int batch_dims = GetBatchDims();
   if (axis < 0) {
     axis += input->shape().size();
   }
@@ -58,18 +58,18 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
     MS_LOG(ERROR) << "OpLstm input dims should be 3.";
     return RET_ERROR;
   }
-  auto lstm_prim = this->primitive->value_as_Lstm();
   int hidden_size = w_shape[1] / 4;
   // set output
   std::vector<int> out_shape(in_shape);
   out_shape[2] = hidden_size;
-  if (lstm_prim->bidirection()) {
+  if (GetBidirection()) {
     out_shape.insert(out_shape.begin() + 1, 2);
   }
   output->set_shape(out_shape);
   // set hidden state, cell state
   std::vector<int> state_shape(in_shape);
-  state_shape[0] = lstm_prim->bidirection() ? 2 : 1;
+  state_shape[0] = GetBidirection() ? 2 : 1;
   state_shape[2] = hidden_size;
   outputs_[1]->set_shape(state_shape);
   outputs_[2]->set_shape(state_shape);
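Note: a standalone sketch of the LSTM shape rule above, assuming a 3-D input laid out as [seq_len, batch, input_size] and a gate weight whose dim 1 packs the 4 gates (both assumptions inferred from the hunk, not confirmed by it):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper: output keeps the input layout with the last dim set
// to hidden_size (gaining a direction dim of 2 when bidirectional); hidden
// and cell states are [num_directions, batch, hidden_size].
void LstmShapes(const std::vector<int> &in_shape, int w_dim1, bool bidirection,
                std::vector<int> *out_shape, std::vector<int> *state_shape) {
  int hidden_size = w_dim1 / 4;  // 4 gates packed along the weight dim
  *out_shape = in_shape;
  (*out_shape)[2] = hidden_size;
  if (bidirection) out_shape->insert(out_shape->begin() + 1, 2);
  *state_shape = in_shape;
  (*state_shape)[0] = bidirection ? 2 : 1;
  (*state_shape)[2] = hidden_size;
}

int main() {
  std::vector<int> out, state;
  LstmShapes({5, 8, 16}, 128, true, &out, &state);
  printf("out: %d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 5 2 8 32
  printf("state: %d %d %d\n", state[0], state[1], state[2]);     // 2 8 32
  return 0;
}
```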
@@ -62,11 +62,11 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
       return RET_INPUT_TENSOR_ERROR;
     }
   }
-  auto matmul_prim = this->primitive->value_as_MatMul();
-  if (matmul_prim->transposeA()) {
+  if (GetTransposeA()) {
     std::swap(a_shape[a_shape.size() - 1], a_shape[a_shape.size() - 2]);
   }
-  if (matmul_prim->transposeB()) {
+  if (GetTransposeB()) {
     std::swap(b_shape[b_shape.size() - 1], b_shape[b_shape.size() - 2]);
   }
   std::vector<int> c_shape(a_shape);
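Note: transposeA/transposeB only swap the two innermost dims before the usual shape rule. A minimal sketch (helper name ours; the inner-dimension agreement check done elsewhere in the real function is omitted here):

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Hypothetical helper for the hunk above: after optional transposes,
// [.., m, k] x [.., k, n] -> [.., m, n].
std::vector<int> MatMulShape(std::vector<int> a, std::vector<int> b,
                             bool transpose_a, bool transpose_b) {
  if (transpose_a) std::swap(a[a.size() - 1], a[a.size() - 2]);
  if (transpose_b) std::swap(b[b.size() - 1], b[b.size() - 2]);
  std::vector<int> c(a);
  c[c.size() - 1] = b[b.size() - 1];  // n comes from the second operand
  return c;
}

int main() {
  auto c = MatMulShape({4, 3}, {4, 5}, /*transpose_a=*/true, false);
  printf("%d %d\n", c[0], c[1]);  // 3 5
  return 0;
}
```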
@@ -58,12 +58,12 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
   if (this->primitive == nullptr) {
     return RET_NULL_PTR;
   }
-  auto mean_prim = this->primitive->value_as_Mean();
-  bool keep_dims = static_cast<bool>(mean_prim->keepDims());
+  bool keep_dims = static_cast<bool>(GetKeepDims());
   std::vector<int> in_shape = input->shape();
   std::vector<int> out_shape;
-  const auto &axes = mean_prim->axis();
-  auto num_axes = axes->size();
+  const auto &axes = GetAxis();
+  auto num_axes = axes.size();
   // reduce on all axes
   if (num_axes == 0) {
     if (keep_dims) {
@@ -79,7 +79,7 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
   for (size_t i = 0; i < in_shape.size(); i++) {
     bool reduce_axis = false;
     for (int idx = 0; idx < num_axes; ++idx) {
-      if (static_cast<size_t>((*axes)[idx]) == i) {
+      if (static_cast<size_t>(axes[idx]) == i) {
        reduce_axis = true;
        break;
      }
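Note: Mean (and Reduce further down) share one reduction shape rule. A minimal standalone sketch (helper name ours, not the lite API):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the hunks above: dimensions named in axes
// are dropped, or kept as 1 when keep_dims is set; an empty axes list
// reduces over every dimension.
std::vector<int> ReduceShape(const std::vector<int> &in_shape,
                             const std::vector<int> &axes, bool keep_dims) {
  std::vector<int> out;
  if (axes.empty()) {  // reduce on all axes
    if (keep_dims) out.assign(in_shape.size(), 1);
    return out;
  }
  for (size_t i = 0; i < in_shape.size(); ++i) {
    bool reduce_axis = false;
    for (int a : axes) {
      if (static_cast<size_t>(a) == i) { reduce_axis = true; break; }
    }
    if (!reduce_axis) out.push_back(in_shape[i]);
    else if (keep_dims) out.push_back(1);
  }
  return out;
}

int main() {
  auto s = ReduceShape({2, 3, 4}, {1}, /*keep_dims=*/false);
  printf("%zu dims: %d %d\n", s.size(), s[0], s[1]);  // 2 dims: 2 4
  return 0;
}
```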
@@ -37,11 +37,8 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
   if (this->primitive == nullptr) {
     return RET_NULL_PTR;
   }
-  auto one_hot_prim = this->primitive->value_as_OneHot();
-  if (one_hot_prim == nullptr) {
-    return RET_NULL_PTR;
-  }
-  int axis = one_hot_prim->axis();
+  int axis = GetAxis();
   // indices, depth, on_value, off_value
   if (inputs.size() != kOneHotInputNum) {
     MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
@@ -49,14 +49,9 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
   if (this->primitive == nullptr) {
     return RET_NULL_PTR;
   }
-  auto pad_prim = this->primitive->value_as_Pad();
-  if (pad_prim == nullptr) {
-    return RET_NULL_PTR;
-  }
-  auto paddings = pad_prim->paddings();
-  if (paddings == nullptr) {
-    return RET_NULL_PTR;
-  }
+  auto paddings = GetPaddings();
   auto input = inputs.front();
   if (input == nullptr) {
     return RET_NULL_PTR;
@@ -75,7 +70,7 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
   MS_ASSERT(input->shape().size() <= kInputRank);
   for (size_t i = 0; i < input_shape.size(); i++) {
     auto paddings_index = i + kInputRank - input_shape.size();
-    auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1];
+    auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1];
     output_shape.push_back(shape);
   }
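Note: the padding arithmetic above right-aligns the (before, after) pairs against the last kInputRank dimensions. A minimal sketch (helper name ours; kInputRank assumed to be 4 as for NHWC tensors):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper for the loop above: paddings holds (before, after)
// pairs; an input with fewer than kInputRank dims indexes them with a
// right-alignment offset.
constexpr size_t kInputRank = 4;

std::vector<int> PadShape(const std::vector<int> &input_shape,
                          const std::vector<int> &paddings) {
  std::vector<int> out;
  for (size_t i = 0; i < input_shape.size(); ++i) {
    size_t p = i + kInputRank - input_shape.size();  // align to the right
    out.push_back(input_shape[i] + paddings[2 * p] + paddings[2 * p + 1]);
  }
  return out;
}

int main() {
  // Pad H by (1,1) and W by (2,2) on a 1x2x2x3 NHWC tensor.
  auto s = PadShape({1, 2, 2, 3}, {0, 0, 1, 1, 2, 2, 0, 0});
  printf("%d %d %d %d\n", s[0], s[1], s[2], s[3]);  // 1 4 6 3
  return 0;
}
```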
@@ -97,7 +97,6 @@ constexpr int kPriorBoxW = 1;
 constexpr int kPriorBoxC = 2;
 } // namespace
 int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
-  auto param = this->primitive->value_as_PriorBox();
   MS_ASSERT(param != nullptr);
   auto input = inputs_.at(0);
   MS_ASSERT(input != nullptr);
@@ -109,20 +108,20 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
     return RET_OK;
   }
   std::vector<float> different_aspect_ratios{1.0f};
-  auto aspect_ratios = param->aspect_ratios();
+  auto aspect_ratios = GetAspectRatios();
   MS_ASSERT(aspect_ratios != nullptr);
-  for (auto i = 0; i < aspect_ratios->size(); i++) {
-    float ratio = (*aspect_ratios)[i];
+  for (auto i = 0; i < aspect_ratios.size(); i++) {
+    float ratio = aspect_ratios[i];
     bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(),
                              [&](float v) { return abs(ratio - v) < 1e-6; });
     if (!exist) {
       different_aspect_ratios.emplace_back(ratio);
-      if (param->flip()) {
+      if (GetFlip()) {
         different_aspect_ratios.emplace_back(1.0f / ratio);
       }
     }
   }
-  int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size();
+  int32_t num_priors_box = GetMinSizes().size() * different_aspect_ratios.size() + GetMaxSizes().size();
   int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints;
   std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC};
   output->set_shape(output_shape);
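Note: the counting logic above de-duplicates aspect ratios (optionally adding the flipped 1/ratio) before sizing the output. A standalone sketch (helper name ours):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper for the hunk above: each feature-map cell emits
// min_sizes * distinct_ratios + max_sizes boxes; the output height is then
// H * W * boxes_per_cell * 4 points.
int NumPriorsPerCell(const std::vector<float> &aspect_ratios, bool flip,
                     size_t num_min_sizes, size_t num_max_sizes) {
  std::vector<float> distinct{1.0f};
  for (float ratio : aspect_ratios) {
    bool exist = false;
    for (float v : distinct) {
      if (std::fabs(ratio - v) < 1e-6f) { exist = true; break; }
    }
    if (!exist) {
      distinct.push_back(ratio);
      if (flip) distinct.push_back(1.0f / ratio);  // mirrored box
    }
  }
  return static_cast<int>(num_min_sizes * distinct.size() + num_max_sizes);
}

int main() {
  // ratios {2} with flip -> {1, 2, 0.5}; 1 min size, 1 max size -> 4 boxes.
  printf("%d\n", NumPriorsPerCell({2.0f}, true, 1, 1));
  return 0;
}
```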
@@ -40,9 +40,8 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
   MS_ASSERT(input != nullptr);
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
-  auto param = primitive->value_as_QuantDTypeCast();
-  MS_ASSERT(input->data_type() == param->srcT);
-  output->set_data_type(static_cast<TypeId>(param->dstT()));
+  output->set_data_type(static_cast<TypeId>(GetDstT()));
   output->SetFormat(input->GetFormat());
   if (!GetInferFlag()) {
     return RET_OK;
@@ -48,7 +48,7 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
   MS_ASSERT(input != nullptr);
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
-  auto range_prim = this->primitive->value_as_Range();
-  MS_ASSERT(range_prim != nullptr);
   output->set_data_type(input->data_type());
@@ -57,7 +57,7 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
     return RET_OK;
   }
-  int shape_size = std::ceil(static_cast<float>(range_prim->limit() - range_prim->start()) / range_prim->delta());
+  int shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
   std::vector<int> in_shape(1);
   in_shape.push_back(shape_size);
   output->set_shape(in_shape);
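Note: the length formula above is the usual half-open range count. A worked sketch (helper name ours):

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical helper: number of elements produced by [start, limit) with
// step delta, rounded up so a partial final step still yields an element.
int RangeSize(int start, int limit, int delta) {
  return static_cast<int>(std::ceil(static_cast<float>(limit - start) / delta));
}

int main() {
  printf("%d\n", RangeSize(0, 10, 3));  // ceil(10 / 3) = 4 -> {0, 3, 6, 9}
  return 0;
}
```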
@@ -62,12 +62,12 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   if (this->primitive == nullptr) {
     return RET_NULL_PTR;
   }
-  auto reduce_prim = this->primitive->value_as_Reduce();
-  bool keep_dims = static_cast<bool>(reduce_prim->keepDims());
+  bool keep_dims = static_cast<bool>(GetKeepDims());
   std::vector<int> in_shape = input->shape();
   std::vector<int> out_shape;
-  const auto &axes = reduce_prim->axes();
-  auto num_axes = axes->size();
+  const auto &axes = GetAxes();
+  auto num_axes = axes.size();
   // reduce on all axes
   if (num_axes == 0) {
     if (keep_dims) {
@@ -83,7 +83,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   for (size_t i = 0; i < in_shape.size(); i++) {
     bool reduce_axis = false;
     for (int idx = 0; idx < num_axes; ++idx) {
-      if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) {
+      if (static_cast<size_t>(axes[idx]) == i || static_cast<size_t>(axes[idx] + in_shape.size()) == i) {
        reduce_axis = true;
        break;
      }
@@ -61,9 +61,9 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto ROIPooling = this->primitive->value_as_ROIPooling();
-  auto new_h = ROIPooling->pooledH();
-  auto new_w = ROIPooling->pooledW();
+  auto new_h = GetPooledH();
+  auto new_w = GetPooledW();
   auto shape_data = roi->shape();
   std::vector<int> output_shape;
   output_shape.push_back(shape_data[0]);
@@ -55,12 +55,10 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
   }
   auto in_shape = in_tensor->shape();
   std::vector<int> out_shape;
-  // todo: getAxis
-  auto squeeze_prim = this->primitive->value_as_Squeeze();
-  MS_EXCEPTION_IF_NULL(squeeze_prim);
-  auto axis = squeeze_prim->axis();
+  auto axis = GetAxis();
   std::vector<int> axes_;
-  for (auto iter = axis->begin(); iter != axis->end(); iter++) {
+  for (auto iter = axis.begin(); iter != axis.end(); iter++) {
     axes_.push_back(*iter);
   }
   if (axes_.size() == 0) {
@@ -62,11 +62,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
     return RET_OK;
   }
   auto input_shape = input->shape();
-  auto stack_prim = this->primitive->value_as_Stack();
   std::vector<int32_t> output_shape = input_shape;
-  int axis = stack_prim->axis() < 0 ? stack_prim->axis() + input_shape.size() : stack_prim->axis();
+  int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
   if (axis < 0 || axis > input_shape.size()) {
-    MS_LOG(ERROR) << "Invalid axis " << stack_prim->axis();
+    MS_LOG(ERROR) << "Invalid axis " << GetAxis();
     return RET_PARAM_INVALID;
   }
   schema::Format input0_format = input->GetFormat();
@@ -174,10 +174,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
   std::vector<int> output_shape;
   ndim_ = static_cast<int>(GetBegin().size());
-  MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
-  MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
-  MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));
   for (int i = 0; i < ndim_; i++) {
     in_shape_.emplace_back(input_shape.at(i));
     begins_.emplace_back((GetBegin())[i]);
@@ -53,10 +53,9 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto topk_prim = this->primitive->value_as_TopK();
-  MS_ASSERT(topk_prim != nullptr);
   auto out_shape = input->shape();
-  out_shape[out_shape.size() - 1] = topk_prim->k();
+  out_shape[out_shape.size() - 1] = GetK();
   output0->set_shape(out_shape);
   output1->set_shape(out_shape);
   return RET_OK;
@@ -53,11 +53,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
   if (!GetInferFlag()) {
     return RET_OK;
   }
-  auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
-  auto dims = unsqueeze_prim->axis()->data();
+  auto dims = GetAxis().data();
   auto in_shape = input->shape();
   auto in_rank = in_shape.size();
-  auto dim_rank = unsqueeze_prim->axis()->size();
+  auto dim_rank = GetAxis().size();
   std::vector<int> out_shape;
   if (dim_rank == 0) {
     for (auto d : in_shape) {
@@ -38,10 +38,10 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
   auto input = inputs.at(0);
   MS_ASSERT(input != nullptr);
   auto input_shape = input->shape();
-  auto prim = this->primitive->value_as_Unstack();
-  int axis = prim->axis() < 0 ? prim->axis() + input_shape.size() : prim->axis();
+  int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
   if (axis < 0 || axis >= input_shape.size()) {
-    MS_LOG(ERROR) << "Invalid axis " << prim->axis();
+    MS_LOG(ERROR) << "Invalid axis " << GetAxis();
     return RET_PARAM_INVALID;
   }
   for (auto &out : outputs) {