Merge pull request !4665 from chenjianping/lite_dev3tags/v0.7.0-beta
| @@ -43,6 +43,11 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T | |||
| MS_LOG(ERROR) << "input size" << inputs.size() << " is error!"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| for (int i = 1; i < inputs.size(); ++i) { | |||
| if (inputs.at(i)->shape() != inputs.at(0)->shape()) { | |||
| MS_LOG(ERROR) << "AddN inputs shape is not equal!"; | |||
| @@ -53,9 +58,8 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -55,6 +55,12 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||
| if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { | |||
| MS_LOG(ERROR) << "tensor number is error."; | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto argmax_prim = this->primitive->value_as_ArgMax(); | |||
| std::vector<int> output_shape(input->shape()); | |||
| auto input_shape_size = input->shape().size(); | |||
| @@ -68,9 +74,8 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||
| } else { | |||
| output_shape[axis] = argmax_prim->topK(); | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -55,6 +55,11 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||
| if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) { | |||
| MS_LOG(ERROR) << "tensor number is error."; | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto argmin_prim = this->primitive->value_as_ArgMin(); | |||
| auto input_shape_size = input->shape().size(); | |||
| int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis(); | |||
| @@ -68,9 +73,8 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||
| } else { | |||
| output_shape[axis] = argmin_prim->topK(); | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -46,6 +46,11 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec | |||
| return 1; | |||
| } | |||
| auto input = inputs.at(0); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(), | |||
| this->primitive->value_as_BroadcastTo()->dst_shape()->end()); | |||
| auto input_shape = input->shape(); | |||
| @@ -72,10 +77,8 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec | |||
| shape[i] = dst_shape[i]; | |||
| --input_shape_index; | |||
| } | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| outputs[0]->set_shape(shape); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -44,8 +44,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| MS_LOG(ERROR) << "tensor number is error."; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| auto cast_prim = this->primitive->value_as_Cast(); | |||
| MS_ASSERT(cast_prim != nullptr); | |||
| output->set_data_type(static_cast<TypeId>(cast_prim->dstT())); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| if (input->data_type() != cast_prim->srcT()) { | |||
| MS_LOG(ERROR) << "input dataType is error"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| @@ -54,13 +60,8 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| MS_LOG(ERROR) << "Unsupported input data type " << input->data_type(); | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT(); | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(TypeId::kNumberTypeFloat32); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -50,16 +50,19 @@ int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect | |||
| return RET_ERROR; | |||
| } | |||
| auto in_tensor = inputs_.front(); | |||
| auto in_data = reinterpret_cast<int *>(in_tensor->Data()); | |||
| auto out_tensor = outputs_.front(); | |||
| out_tensor->set_data_type(kNumberTypeFloat32); | |||
| out_tensor->SetFormat(in_tensor->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto in_data = reinterpret_cast<int *>(in_tensor->Data()); | |||
| int size = in_tensor->ElementsNum(); | |||
| std::vector<int> out_shape(size); | |||
| for (int i = 0; i < size; ++i) { | |||
| out_shape[i] = in_data[i]; | |||
| } | |||
| out_tensor->set_shape(out_shape); | |||
| out_tensor->set_data_type(kNumberTypeFloat32); | |||
| out_tensor->SetFormat(in_tensor->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| @@ -46,9 +46,12 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T | |||
| MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| outputs[0]->set_shape(inputs[1]->shape()); | |||
| outputs[0]->SetFormat(inputs[0]->GetFormat()); | |||
| outputs[0]->set_data_type(inputs[0]->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| outputs[0]->set_shape(inputs[1]->shape()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -103,7 +103,11 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto | |||
| MS_ASSERT(weight != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| int32_t input_h = input->Height(); | |||
| int32_t input_w = input->Width(); | |||
| @@ -138,8 +142,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto | |||
| std::vector<int> out_shape = {output_n, output_h, output_w, output_c}; | |||
| output->set_shape(out_shape); | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| return 0; | |||
| } | |||
| } // namespace lite | |||
| @@ -126,7 +126,11 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, | |||
| MS_ASSERT(weight != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto in_shape = input->shape(); | |||
| int input_h = in_shape.at(1); | |||
| int input_w = in_shape.at(2); | |||
| @@ -155,8 +159,6 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, | |||
| out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel | |||
| output->set_shape(out_shape); | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| return 0; | |||
| } | |||
| } // namespace lite | |||
| @@ -50,6 +50,11 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; | |||
| return 1; | |||
| } | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| if (input_shape.size() != kDimension_4d) { | |||
| MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; | |||
| @@ -68,10 +73,7 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| output_shape[NHWC_W] = input_shape[NHWC_W] * block_size; | |||
| output_shape[NHWC_C] = input_shape[NHWC_C] / (block_size * block_size); | |||
| outputs[0]->set_shape(output_shape); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -120,7 +120,11 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, | |||
| MS_ASSERT(weight != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto in_shape = input->shape(); | |||
| int input_h = in_shape.at(1); | |||
| int input_w = in_shape.at(2); | |||
| @@ -158,8 +162,6 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, | |||
| out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel | |||
| output->set_shape(out_shape); | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| return 0; | |||
| } | |||
| } // namespace lite | |||
| @@ -46,6 +46,12 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect | |||
| MS_ASSERT(ids != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->SetFormat(params_->GetFormat()); | |||
| output->set_data_type(params_->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto embedding_shape = params_->shape(); | |||
| embedding_shape.erase(embedding_shape.begin()); | |||
| std::vector<int> output_shape(ids->shape()); | |||
| @@ -61,7 +67,6 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect | |||
| } | |||
| } | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(params_->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -42,6 +42,11 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te | |||
| if (outputs_.size() != kSingleNum) { | |||
| MS_LOG(ERROR) << "output size is invalid"; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto expand_dims_prim = this->primitive->value_as_ExpandDims(); | |||
| int dim = expand_dims_prim->dim(); | |||
| if (dim < 0) { | |||
| @@ -54,8 +59,6 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te | |||
| auto out_shape = input->shape(); | |||
| out_shape.insert(out_shape.begin() + dim, 1, 1); | |||
| output->set_shape(out_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -45,6 +45,11 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto fill_prim = this->primitive->value_as_Fill(); | |||
| if (fill_prim == nullptr) { | |||
| MS_LOG(ERROR) << "Fill primitive is null!"; | |||
| @@ -53,8 +58,6 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| std::vector<int> output_shape; | |||
| (void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end()); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -31,6 +31,13 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size(); | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| std::vector<int> output_shape(2); | |||
| output_shape[0] = input_shape[0]; | |||
| @@ -39,8 +46,6 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| output_shape[1] *= input_shape[i]; | |||
| } | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -51,7 +51,11 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_, | |||
| MS_ASSERT(input1 != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input0->data_type()); | |||
| output->SetFormat(input0->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) { | |||
| MS_LOG(ERROR) << "Input tensors num error"; | |||
| return 1; | |||
| @@ -78,8 +82,6 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_, | |||
| out_shape.resize(GetAxis() + 1); | |||
| out_shape[GetAxis()] = input1->shape()[0]; | |||
| output->set_shape(out_shape); | |||
| output->set_data_type(input0->data_type()); | |||
| output->SetFormat(input0->GetFormat()); | |||
| return 0; | |||
| } | |||
| @@ -46,6 +46,12 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens | |||
| MS_ASSERT(indices != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto in_shape = input->shape(); | |||
| int in_rank = in_shape.size(); | |||
| auto indices_shape = indices->shape(); | |||
| @@ -63,8 +69,6 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens | |||
| out_shape.emplace_back(in_shape[i]); | |||
| } | |||
| output->set_shape(out_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -44,6 +44,14 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| MS_ASSERT(input0 != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| for (int i = 0; i < kLstmOutputNum; i++) { | |||
| outputs_[i]->set_data_type(input->data_type()); | |||
| outputs_[i]->SetFormat(input->GetFormat()); | |||
| } | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int> in_shape = input->shape(); | |||
| std::vector<int> w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size | |||
| if (in_shape.size() != 3 || w_shape.size() != 3) { | |||
| @@ -65,10 +73,7 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| state_shape[2] = hidden_size; | |||
| outputs_[1]->set_shape(state_shape); | |||
| outputs_[2]->set_shape(state_shape); | |||
| for (int i = 0; i < kLstmOutputNum; i++) { | |||
| outputs_[i]->set_data_type(input->data_type()); | |||
| outputs_[i]->SetFormat(input->GetFormat()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -43,6 +43,13 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||
| MS_ASSERT(input1 != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input0->data_type()); | |||
| output->SetFormat(input0->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int> a_shape = input0->shape(); | |||
| std::vector<int> b_shape = input1->shape(); | |||
| if (a_shape.size() < 2 || b_shape.size() < 2) { | |||
| @@ -65,8 +72,6 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||
| std::vector<int> c_shape(a_shape); | |||
| c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1]; | |||
| output->set_shape(c_shape); | |||
| output->set_data_type(input0->data_type()); | |||
| output->SetFormat(input0->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -50,6 +50,11 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| if (input == nullptr || output == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| if (this->primitive == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| @@ -88,8 +93,6 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| } | |||
| } | |||
| output->set_shape(out_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -25,6 +25,11 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->SetFormat(schema::Format_NHWC); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int> nchw_shape = input->shape(); | |||
| if (nchw_shape.size() != 4) { | |||
| output->set_shape(nchw_shape); | |||
| @@ -36,8 +41,6 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect | |||
| nhwc_shape[NHWC_C] = nchw_shape[NCHW_C]; | |||
| output->set_shape(nhwc_shape); | |||
| } | |||
| output->SetFormat(schema::Format_NHWC); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -25,6 +25,11 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->SetFormat(schema::Format_NCHW); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int> nhwc_shape = input->shape(); | |||
| if (nhwc_shape.size() != 4) { | |||
| output->set_shape(nhwc_shape); | |||
| @@ -36,8 +41,6 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect | |||
| nchw_shape[NCHW_W] = nhwc_shape[NHWC_W]; | |||
| output->set_shape(nchw_shape); | |||
| } | |||
| output->SetFormat(schema::Format_NCHW); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -56,6 +56,19 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor: | |||
| if (input == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto on_value = inputs.at(2); | |||
| if (on_value == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto output = outputs.front(); | |||
| if (output == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| output->set_data_type(on_value->data_type()); | |||
| output->SetFormat(on_value->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| const auto input_shape = input->shape(); | |||
| int input_rank = static_cast<int>(input_shape.size()); | |||
| if (axis < 0) { | |||
| @@ -63,17 +76,7 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor: | |||
| } | |||
| std::vector<int> output_shape(input_shape); | |||
| output_shape.insert(output_shape.cbegin() + axis, *depth); | |||
| auto output = outputs.front(); | |||
| if (output == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| output->set_shape(output_shape); | |||
| auto on_value = inputs.at(2); | |||
| if (on_value == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| output->set_data_type(on_value->data_type()); | |||
| output->SetFormat(on_value->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -61,6 +61,15 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te | |||
| if (input == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto output = outputs.front(); | |||
| if (output == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| std::vector<int> output_shape; | |||
| MS_ASSERT(input->shape().size() <= kInputRank); | |||
| @@ -69,13 +78,8 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te | |||
| auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1]; | |||
| output_shape.push_back(shape); | |||
| } | |||
| auto output = outputs.front(); | |||
| if (output == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -95,6 +95,11 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(schema::Format_NHWC); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| int input_h = input->shape().at(1); | |||
| int input_w = input->shape().at(2); | |||
| auto pooling_prim = this->primitive->value_as_Pooling(); | |||
| @@ -137,9 +142,6 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| input_shape.at(1) = output_h; | |||
| input_shape.at(2) = output_w; | |||
| output->set_shape(input_shape); | |||
| output->set_data_type(input->data_type()); | |||
| // todo: temp fix | |||
| output->SetFormat(schema::Format_NHWC); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -49,15 +49,19 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:: | |||
| } | |||
| auto output_tensor = outputs[0]; | |||
| MS_ASSERT(output_tensor != nullptr); | |||
| output_tensor->set_data_type(x_tensor->data_type()); | |||
| output_tensor->SetFormat(x_tensor->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| if (exp_tensor != nullptr) { | |||
| if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) { | |||
| MS_LOG(ERROR) << "Power inputs shape or type is not equal!"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| } | |||
| output_tensor->SetFormat(x_tensor->GetFormat()); | |||
| output_tensor->set_shape(x_tensor->shape()); | |||
| output_tensor->set_data_type(x_tensor->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -99,6 +99,15 @@ constexpr int kPriorBoxC = 2; | |||
| int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) { | |||
| auto param = this->primitive->value_as_PriorBox(); | |||
| MS_ASSERT(param != nullptr); | |||
| auto input = inputs_.at(0); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.at(0); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(kNumberTypeFloat32); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<float> different_aspect_ratios{1.0f}; | |||
| auto aspect_ratios = param->aspect_ratios(); | |||
| MS_ASSERT(aspect_ratios != nullptr); | |||
| @@ -114,15 +123,9 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens | |||
| } | |||
| } | |||
| int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size(); | |||
| auto input = inputs_.at(0); | |||
| MS_ASSERT(input != nullptr); | |||
| int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints; | |||
| std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC}; | |||
| auto output = outputs_.at(0); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(kNumberTypeFloat32); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -40,11 +40,14 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| auto param = primitive->value_as_QuantDTypeCast(); | |||
| MS_ASSERT(input->data_type() == param->srcT); | |||
| output->set_data_type(static_cast<TypeId>(param->dstT())); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| output->set_shape(input->shape()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -50,12 +50,18 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: | |||
| MS_ASSERT(output != nullptr); | |||
| auto range_prim = this->primitive->value_as_Range(); | |||
| MS_ASSERT(range_prim != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| int shape_size = std::ceil(static_cast<float>(range_prim->limit() - range_prim->start()) / range_prim->delta()); | |||
| std::vector<int> in_shape(1); | |||
| in_shape.push_back(shape_size); | |||
| output->set_shape(in_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -25,10 +25,13 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| std::vector<int> in_shape(1, 1); | |||
| output->set_shape(in_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int> in_shape(1, 1); | |||
| output->set_shape(in_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -66,6 +66,11 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector< | |||
| if (output == nullptr) { | |||
| return 1; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto new_height = GetNewHeight(); | |||
| auto new_width = GetNewWidth(); | |||
| @@ -75,10 +80,8 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector< | |||
| output_shape.push_back(new_width); | |||
| output_shape.push_back(input->Channel()); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -52,9 +52,13 @@ int ReverseSequence::InferShape(std::vector<tensor::Tensor *> inputs, std::vecto | |||
| auto output = outputs.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| output->set_shape(input->shape()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -56,6 +56,11 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te | |||
| if (output == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto ROIPooling = this->primitive->value_as_ROIPooling(); | |||
| auto new_h = ROIPooling->pooledH(); | |||
| auto new_w = ROIPooling->pooledW(); | |||
| @@ -66,8 +71,6 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te | |||
| output_shape.push_back(new_w); | |||
| output_shape.push_back(input->Channel()); | |||
| output->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -51,11 +51,14 @@ int ScatterND::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten | |||
| return RET_ERROR; | |||
| } | |||
| auto output = outputs_.front(); | |||
| output->set_data_type(update->data_type()); | |||
| output->SetFormat(update->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto shape_data = reinterpret_cast<int *>(shape->Data()); | |||
| std::vector<int> out_shape(shape_data, shape_data + shape->DataSize()); | |||
| output->set_shape(out_shape); | |||
| output->set_data_type(update->data_type()); | |||
| output->SetFormat(update->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -63,6 +63,11 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| MS_LOG(ERROR) << "space_to_batch only support NHWC now!"; | |||
| return 1; | |||
| } | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| if (input_shape.size() != kDimension_4d) { | |||
| MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; | |||
| @@ -106,8 +111,7 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| output_shape[NHWC_W] = input_shape[NHWC_W] / block_sizes_[NHWC_H]; | |||
| output_shape[NHWC_C] = input_shape[NHWC_C]; | |||
| outputs[0]->set_shape(output_shape); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -51,6 +51,11 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| MS_LOG(ERROR) << "space_to_depth only support NHWC now!"; | |||
| return 1; | |||
| } | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| if (input_shape.size() != kDimension_4d) { | |||
| MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; | |||
| @@ -69,8 +74,7 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| output_shape[NHWC_W] = input_shape[NHWC_W] / block_size; | |||
| output_shape[NHWC_C] = input_shape[NHWC_C] * (block_size * block_size); | |||
| outputs[0]->set_shape(output_shape); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -66,6 +66,13 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: | |||
| MS_LOG(ERROR) << "outputs number is not equal to " << number_split; | |||
| return RET_ERROR; | |||
| } | |||
| for (int i = 0; i < number_split; ++i) { | |||
| outputs_[i]->set_data_type(input->data_type()); | |||
| outputs_[i]->SetFormat(input->GetFormat()); | |||
| } | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| int split_dim = spilt_prim->splitDim(); | |||
| std::vector<int> input_shape = input->shape(); | |||
| std::vector<int> size_split; | |||
| @@ -48,6 +48,11 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| return -1; | |||
| } | |||
| auto *in_tensor = inputs_.front(); | |||
| outputs_.front()->set_data_type(in_tensor->data_type()); | |||
| outputs_.front()->SetFormat(in_tensor->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto in_shape = in_tensor->shape(); | |||
| std::vector<int> out_shape; | |||
| // todo: getAxis | |||
| @@ -77,8 +82,6 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso | |||
| } | |||
| } | |||
| outputs_.front()->set_shape(out_shape); | |||
| outputs_.front()->set_data_type(in_tensor->data_type()); | |||
| outputs_.front()->SetFormat(in_tensor->GetFormat()); | |||
| return 0; | |||
| } | |||
| } // namespace lite | |||
| @@ -56,6 +56,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:: | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto input = inputs.at(0); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| auto stack_prim = this->primitive->value_as_Stack(); | |||
| std::vector<int32_t> output_shape = input_shape; | |||
| @@ -84,8 +89,6 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:: | |||
| } | |||
| output_shape.insert(output_shape.begin() + axis, inputs.size()); | |||
| outputs[0]->set_shape(output_shape); | |||
| outputs[0]->set_data_type(input->data_type()); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -164,6 +164,11 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto input = inputs.at(0); | |||
| outputs.front()->set_data_type(input->data_type()); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| MS_ASSERT(input != nullptr); | |||
| auto input_shape = input->shape(); | |||
| std::vector<int> output_shape; | |||
| @@ -214,8 +219,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve | |||
| output_shape = ApplyShrinkMask(output_shape); | |||
| outputs.front()->set_shape(output_shape); | |||
| outputs.front()->set_data_type(input->data_type()); | |||
| outputs[0]->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| @@ -40,6 +40,11 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto tile_prim = this->primitive->value_as_Tile(); | |||
| MS_ASSERT(tile_prim != nullptr); | |||
| std::vector<int> out_shape; | |||
| @@ -49,9 +54,8 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| int tmp = input->shape()[i] * multiples[i]; | |||
| out_shape.push_back(tmp); | |||
| } | |||
| output->SetFormat(input->GetFormat()); | |||
| output->set_shape(out_shape); | |||
| output->set_data_type(input->data_type()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -46,16 +46,19 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| MS_ASSERT(output0 != nullptr); | |||
| auto output1 = outputs_.at(1); | |||
| MS_ASSERT(output1 != nullptr); | |||
| output0->set_data_type(input->data_type()); | |||
| output0->SetFormat(input->GetFormat()); | |||
| output1->set_data_type(kNumberTypeInt32); | |||
| output1->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto topk_prim = this->primitive->value_as_TopK(); | |||
| MS_ASSERT(topk_prim != nullptr); | |||
| auto out_shape = input->shape(); | |||
| out_shape[out_shape.size() - 1] = topk_prim->k(); | |||
| output0->set_shape(out_shape); | |||
| output0->set_data_type(input->data_type()); | |||
| output0->SetFormat(input->GetFormat()); | |||
| output1->set_shape(out_shape); | |||
| output1->set_data_type(kNumberTypeInt32); | |||
| output1->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -42,12 +42,15 @@ int Unique::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||
| MS_ASSERT(output0 != nullptr); | |||
| auto &output1 = outputs_.at(1); | |||
| MS_ASSERT(output1 != nullptr); | |||
| output0->set_shape(input->shape()); | |||
| output0->set_data_type(input->data_type()); | |||
| output1->set_shape(input->shape()); | |||
| output1->set_data_type(kNumberTypeInt32); | |||
| output1->SetFormat(input->GetFormat()); | |||
| output0->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| output0->set_shape(input->shape()); | |||
| output1->set_shape(input->shape()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -44,6 +44,14 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor | |||
| MS_LOG(ERROR) << "Invalid axis " << prim->axis(); | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| for (auto &out : outputs) { | |||
| MS_ASSERT(out != nullptr); | |||
| out->set_data_type(input->data_type()); | |||
| out->SetFormat(input->GetFormat()); | |||
| } | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int> output_shape; | |||
| for (size_t i = 0; i < input_shape.size(); ++i) { | |||
| if (i != axis) { | |||
| @@ -53,8 +61,6 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor | |||
| for (auto &out : outputs) { | |||
| MS_ASSERT(out != nullptr); | |||
| out->set_shape(output_shape); | |||
| out->set_data_type(input->data_type()); | |||
| out->SetFormat(input->GetFormat()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -53,6 +53,11 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: | |||
| auto input0 = inputs_.at(0); | |||
| auto input1 = inputs_.at(1); | |||
| auto input2 = inputs_.at(2); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| int num = input0->ElementsNum(); | |||
| int num1 = input1->ElementsNum(); | |||
| int num2 = input2->ElementsNum(); | |||
| @@ -85,8 +90,6 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor: | |||
| auto output_shape = shape_tmp; | |||
| output_shape[axisout] = nummax; | |||
| outputs_[0]->set_shape(output_shape); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -29,10 +29,12 @@ int ZerosLike::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect | |||
| << ", output size: " << outputs_.size(); | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| output->set_shape(input->shape()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -18,15 +18,29 @@ | |||
| #include <float.h> | |||
| int ArgCompareAscFp32(const void *a, const void *b) { | |||
| return ((ArgElement *)a)->data_.f_data_ - ((ArgElement *)b)->data_.f_data_; | |||
| float a_value = ((ArgElement *)a)->data_.f_data_; | |||
| float b_value = ((ArgElement *)b)->data_.f_data_; | |||
| if (b_value > a_value) { | |||
| return -1; | |||
| } | |||
| if (b_value < a_value) { | |||
| return 1; | |||
| } | |||
| return 0; | |||
| } | |||
| int ArgCompareDescFp32(const void *a, const void *b) { | |||
| // cmp funtion of qsort must return int type | |||
| auto b_value = ((ArgElement *)b)->data_.f_data_; | |||
| auto a_value = ((ArgElement *)a)->data_.f_data_; | |||
| int res = b_value > a_value ? 1 : -1; | |||
| return res; | |||
| float b_value = ((ArgElement *)b)->data_.f_data_; | |||
| float a_value = ((ArgElement *)a)->data_.f_data_; | |||
| if (b_value > a_value) { | |||
| return 1; | |||
| } | |||
| if (b_value < a_value) { | |||
| return -1; | |||
| } | |||
| return 0; | |||
| } | |||
| void ArgMaxDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) { | |||