From: @mengyuanli Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -406,8 +406,8 @@ int Conv2D::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp | |||
| this->ConvInferShape(input_h, input_w, &output_h, &output_w); | |||
| std::vector<int> out_shape{input_tensor->shape()}; | |||
| out_shape.at(1) = output_h > 0 ? output_h : 1; | |||
| out_shape.at(2) = output_w > 0 ? output_w : 1; | |||
| out_shape.at(1) = output_h >= 0 ? output_h : 1; | |||
| out_shape.at(2) = output_w >= 0 ? output_w : 1; | |||
| out_shape.at(3) = weight_tensor->shape()[0]; | |||
| out_tensor->set_shape(out_shape); | |||
| @@ -66,14 +66,25 @@ PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC | |||
| Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator); | |||
| #endif | |||
| int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| MS_ASSERT(inputs_.size() == 2 * outputs_.size()); | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| InferStatus Merge::AbleToInfer(const std::vector<lite::Tensor *> &inputs) { | |||
| for (auto &input : inputs) { | |||
| if (input->shape().empty()) { | |||
| return HasZeroShape; | |||
| } | |||
| if (input->root_tensor() != nullptr && input->root_tensor()->data_c() != nullptr) { | |||
| continue; | |||
| } | |||
| if (input->data_c() == nullptr) { | |||
| return NotAble; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < inputs_.size() / 2; i++) { | |||
| auto *input = inputs_[i]; | |||
| auto *output = outputs_[i]; | |||
| return Able; | |||
| } | |||
| int Merge::Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs) { | |||
| for (size_t i = 0; i < inputs.size(); i++) { | |||
| auto *input = inputs[i]; | |||
| auto *output = outputs[i]; | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << "input tensor is nullptr"; | |||
| return RET_ERROR; | |||
| @@ -98,5 +109,35 @@ int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu | |||
| } | |||
| return RET_OK; | |||
| } | |||
| // Shape inference for Merge, whose inputs_ are treated as two candidate halves of equal length | |||
| // (inputs_.size() is asserted to be 2 * outputs_.size()). | |||
| // Output data types are copied from the leading inputs before anything else, so they are set | |||
| // even when this function ends up returning RET_INFER_INVALID. | |||
| // Selection order: the first half if AbleToInfer() reports Able for it, otherwise the second | |||
| // half if that one is Able, otherwise the first half again when both halves report | |||
| // HasZeroShape. Any other combination, or infer_flag() being false, yields RET_INFER_INVALID. | |||
| int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
|   MS_ASSERT(inputs_.size() == 2 * outputs_.size()); | |||
|   size_t out_idx = 0; | |||
|   for (auto *out_tensor : outputs_) { | |||
|     out_tensor->set_data_type(inputs_[out_idx]->data_type()); | |||
|     ++out_idx; | |||
|   } | |||
|   if (!infer_flag()) { | |||
|     return RET_INFER_INVALID; | |||
|   } | |||
|   // Split the inputs down the middle into the two alternatives. | |||
|   const auto split_point = inputs_.begin() + inputs_.size() / 2; | |||
|   const std::vector<Tensor *> first_half(inputs_.begin(), split_point); | |||
|   const std::vector<Tensor *> second_half(split_point, inputs_.end()); | |||
|   if (AbleToInfer(first_half) == Able) { | |||
|     return Infer(first_half, outputs_); | |||
|   } | |||
|   if (AbleToInfer(second_half) == Able) { | |||
|     return Infer(second_half, outputs_); | |||
|   } | |||
|   // Neither half is Able: fall back to the first half only when both halves have an empty shape. | |||
|   const bool both_zero_shape = | |||
|     AbleToInfer(first_half) == HasZeroShape && AbleToInfer(second_half) == HasZeroShape; | |||
|   return both_zero_shape ? Infer(first_half, outputs_) : RET_INFER_INVALID; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
// Result of Merge::AbleToInfer for one group of input tensors:
//   Able         - no input caused an early return, so shape inference can use this group.
//   NotAble      - some input has no data (data_c() is null and its root tensor, if any, has none).
//   HasZeroShape - some input reports an empty shape().
// Merge::InferShape compares against Able and HasZeroShape to pick which half of its inputs to use.
| enum InferStatus { Able, NotAble, HasZeroShape }; | |||
| class Merge : public PrimitiveC { | |||
| public: | |||
| @@ -37,6 +38,10 @@ class Merge : public PrimitiveC { | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| private: | |||
| static InferStatus AbleToInfer(const std::vector<lite::Tensor *> &inputs); | |||
| static int Infer(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -116,12 +116,12 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) c | |||
| for (size_t i = 0; i < in_tensor->shape().size(); i++) { | |||
| in_shape_size *= in_tensor->shape().at(i); | |||
| } | |||
| int64_t inferIndex = -1; | |||
| size_t out_shapeSize = 1; | |||
| int64_t infer_index = -1; | |||
| size_t out_shape_size = 1; | |||
| for (size_t i = 0; i < out_shape->size(); i++) { | |||
| if (out_shape->at(i) == -1) { | |||
| if (inferIndex == -1) { | |||
| inferIndex = i; | |||
| if (infer_index == -1) { | |||
| infer_index = i; | |||
| } else { | |||
| MS_LOG(ERROR) << "output shape should has no more than one dim which need infer"; | |||
| return RET_INFER_ERR; | |||
| @@ -130,18 +130,23 @@ int Reshape::CalNewShape(const Tensor *in_tensor, std::vector<int> *out_shape) c | |||
| MS_LOG(ERROR) << "output shape dim should be non-negative"; | |||
| return RET_INFER_ERR; | |||
| } else if (out_shape->at(i) == 0) { | |||
| out_shape->at(i) = in_tensor->shape().at(i); | |||
| out_shapeSize *= out_shape->at(i); | |||
| if (in_tensor->ElementsNum() != 0) { | |||
| out_shape->at(i) = in_tensor->shape().at(i); | |||
| out_shape_size *= out_shape->at(i); | |||
| } else { | |||
| out_shape_size = 0; | |||
| break; | |||
| } | |||
| } else { | |||
| out_shapeSize *= out_shape->at(i); | |||
| out_shape_size *= out_shape->at(i); | |||
| } | |||
| } | |||
| if (inferIndex == -1 && out_shapeSize != in_shape_size) { | |||
| MS_LOG(ERROR) << "output shapeSize: " << out_shapeSize << " should be equal to input shapeSize: " << in_shape_size; | |||
| if (infer_index == -1 && out_shape_size != in_shape_size) { | |||
| MS_LOG(ERROR) << "output shapeSize: " << out_shape_size << " should be equal to input shapeSize: " << in_shape_size; | |||
| return RET_INFER_ERR; | |||
| } | |||
| if (inferIndex != -1) { | |||
| out_shape->at(inferIndex) = in_shape_size / out_shapeSize; | |||
| if (infer_index != -1) { | |||
| out_shape->at(infer_index) = in_shape_size / out_shape_size; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -39,7 +39,7 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin, | |||
| } | |||
| lite::STATUS ret; | |||
| if (src_tensor->data_type() == kObjectTypeTensorType && dst_tensor->data_type() == kObjectTypeTensorType) { | |||
| ret = MoveTensorLiteData(reinterpret_cast<lite::TensorList *>(dst_tensor), | |||
| ret = MoveTensorListData(reinterpret_cast<lite::TensorList *>(dst_tensor), | |||
| reinterpret_cast<lite::TensorList *>(src_tensor)); | |||
| } else { | |||
| ret = MoveTensorData(dst_tensor, src_tensor); | |||
| @@ -55,7 +55,13 @@ int CarryDataKernel::MoveData(std::vector<lite::Tensor *>::iterator dst_begin, | |||
| int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor) { | |||
| if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format() || | |||
| !(dst_tensor->shape() == src_tensor->shape() || (dst_tensor->shape().empty() && src_tensor->shape().empty()))) { | |||
| MS_LOG(ERROR) << "input tensor and output tensor is incompatible"; | |||
| MS_LOG(ERROR) << "input tensor and output tensor is incompatible."; | |||
| MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs " | |||
| << "output tensor data_type: " << dst_tensor->data_type() | |||
| << "input tensor format: " << src_tensor->format() << " vs " | |||
| << "output tensor format: " << dst_tensor->format() << "input tensor shape: " << src_tensor->shape() | |||
| << " vs " | |||
| << "output tensor shape: " << dst_tensor->shape(); | |||
| return RET_ERROR; | |||
| } | |||
| if (src_tensor->root_tensor() == nullptr) { | |||
| @@ -83,18 +89,19 @@ int CarryDataKernel::MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_ | |||
| return RET_OK; | |||
| } | |||
| int CarryDataKernel::MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) { | |||
| int CarryDataKernel::MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor) { | |||
| // shape may change, because tensors.size() can change in RunGraph | |||
| if (dst_tensor->data_type() != src_tensor->data_type() || dst_tensor->format() != src_tensor->format()) { | |||
| MS_LOG(ERROR) << "input tensorlist and output tensorlist data_type or format is incompatible"; | |||
| MS_LOG(ERROR) << "input tensor data_type: " << src_tensor->data_type() << " vs " | |||
| << "output tensor data_type: " << dst_tensor->data_type() | |||
| << "input tensor format: " << src_tensor->format() << " vs " | |||
| << "output tensor format: " << dst_tensor->format(); | |||
| return RET_ERROR; | |||
| } | |||
| if (dst_tensor->element_shape().empty()) { | |||
| dst_tensor->set_element_shape(src_tensor->element_shape()); | |||
| } else if (dst_tensor->element_shape() != src_tensor->element_shape()) { | |||
| MS_LOG(ERROR) << "input tensorlist and output tensorlist element shape is incompatible"; | |||
| return RET_ERROR; | |||
| } | |||
| // once tensorlist malloc is done, this needs to check element_shape compatibility | |||
| dst_tensor->set_element_shape(src_tensor->element_shape()); | |||
| auto update_data_type = kTypeUnknown; | |||
| auto dst_tensor_data_type = dst_tensor->tensors_data_type(); | |||
| auto src_tensor_data_type = src_tensor->tensors_data_type(); | |||
| @@ -34,7 +34,7 @@ class CarryDataKernel : public LiteKernel { | |||
| int MoveData(std::vector<lite::Tensor *>::iterator dst_begin, std::vector<lite::Tensor *>::iterator dst_end, | |||
| std::vector<lite::Tensor *>::iterator src_begin, std::vector<lite::Tensor *>::iterator src_limit); | |||
| static int MoveTensorData(lite::Tensor *dst_tensor, lite::Tensor *src_tensor); | |||
| static int MoveTensorLiteData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor); | |||
| static int MoveTensorListData(lite::TensorList *dst_tensor, lite::TensorList *src_tensor); | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -146,6 +146,14 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) { | |||
| } | |||
| int TensorListStackCPUKernel::Run() { | |||
| if (dtype_ == kTypeUnknown) { | |||
| dtype_ = input0_->tensors_data_type(); | |||
| #ifdef ENABLE_FP16 | |||
| if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) { | |||
| dtype_ = kNumberTypeFloat16; | |||
| } | |||
| #endif | |||
| } | |||
| if (CheckParam() != RET_OK) { | |||
| MS_LOG(ERROR) << "CheckParam failed!"; | |||
| return RET_ERROR; | |||
| @@ -169,7 +177,10 @@ int TensorListStackCPUKernel::Run() { | |||
| MS_ASSERT(out_data != nullptr); | |||
| for (int i = 0; i < num_element_; ++i) { | |||
| auto in_ptr = input0_->GetTensor(i); | |||
| MS_ASSERT(in_ptr != nullptr); | |||
| if (in_ptr == nullptr) { | |||
| MS_LOG(DEBUG) << "no need to stack."; | |||
| continue; | |||
| } | |||
| if (in_ptr->data_type() != kTypeUnknown) { | |||
| int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_); | |||
| auto in_data = in_ptr->data_c(); | |||
| @@ -44,3 +44,4 @@ ml_video_edit_style_transfer_gongnongbing.onnx | |||
| ml_video_edit_style_transfer_starry.onnx | |||
| ml_video_edit_judge.onnx | |||
| ml_video_edit_vignet.onnx | |||
| ssd_mobilenet_v1_10.onnx;1,383,640,3 | |||
| @@ -23,7 +23,7 @@ namespace lite { | |||
| lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph, | |||
| const onnx::NodeProto &onnx_node) { | |||
| MS_LOG(DEBUG) << "onnx NonZeroParser"; | |||
| auto attr = std::make_unique<schema::NonZeroT>(); | |||
| auto attr = std::make_unique<schema::WhereT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return nullptr; | |||
| @@ -33,7 +33,7 @@ lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto & | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_NonZero; | |||
| primitive->value.type = schema::PrimitiveType_Where; | |||
| primitive->value.value = attr.release(); | |||
| return PrimitiveC::Create(primitive.release()); | |||
| } | |||