| @@ -179,6 +179,13 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out | |||||
| std::vector<int> out_shape; | std::vector<int> out_shape; | ||||
| if (inputs_.size() == kDoubleNum) { | if (inputs_.size() == kDoubleNum) { | ||||
| auto shape_tensor = inputs_.at(1); | auto shape_tensor = inputs_.at(1); | ||||
| if (input->ElementsNum() == 1) { | |||||
| if (shape_tensor->shape().empty()) { | |||||
| MS_LOG(DEBUG) << "reshape to a scalar."; | |||||
| output->set_shape(out_shape); | |||||
| return RET_OK; | |||||
| } | |||||
| } | |||||
| if (shape_tensor->data_c() == nullptr) { | if (shape_tensor->data_c() == nullptr) { | ||||
| MS_LOG(INFO) << "Do infer shape in runtime."; | MS_LOG(INFO) << "Do infer shape in runtime."; | ||||
| return RET_INFER_INVALID; | return RET_INFER_INVALID; | ||||
| @@ -52,8 +52,8 @@ int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu | |||||
| MS_LOG(ERROR) << "new primitiveT value failed"; | MS_LOG(ERROR) << "new primitiveT value failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->splitDim = GetValue<int32_t>(prim.GetAttr("axis")); | |||||
| attr->numberSplit = GetValue<int32_t>(prim.GetAttr("output_num")); | |||||
| attr->splitDim = CastToInt(prim.GetAttr("axis")).front(); | |||||
| attr->numberSplit = CastToInt(prim.GetAttr("output_num")).front(); | |||||
| this->primitive_->value.value = attr; | this->primitive_->value.value = attr; | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| MS_LOG(ERROR) << "primitive value is nullptr"; | MS_LOG(ERROR) << "primitive value is nullptr"; | ||||
| @@ -177,6 +177,15 @@ constexpr size_t kStridedSliceInputNum = 1; | |||||
| constexpr size_t kStridedSliceMultiInputNumMin = 3; | constexpr size_t kStridedSliceMultiInputNumMin = 3; | ||||
| constexpr size_t kStridedSliceMultiInputNumMax = 5; | constexpr size_t kStridedSliceMultiInputNumMax = 5; | ||||
| } // namespace | } // namespace | ||||
| bool StridedSlice::CheckInputs(std::vector<lite::Tensor *> inputs_) { | |||||
| for (size_t i = 1; i < inputs_.size(); ++i) { | |||||
| if (inputs_[i]->data_c() == nullptr) { | |||||
| MS_LOG(DEBUG) << "strided_slice has input from other node, which only can be obtained when running."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void StridedSlice::ApplyNewAxisMask() { | void StridedSlice::ApplyNewAxisMask() { | ||||
| for (size_t i = 0; i < new_axis_mask_.size(); i++) { | for (size_t i = 0; i < new_axis_mask_.size(); i++) { | ||||
| @@ -365,6 +374,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit | |||||
| strides_.emplace_back((GetStride())[i]); | strides_.emplace_back((GetStride())[i]); | ||||
| } | } | ||||
| } | } | ||||
| if (!CheckInputs(inputs)) { | |||||
| MS_LOG(DEBUG) << "Do infer shape in runtime."; | |||||
| return RET_INFER_INVALID; | |||||
| } | |||||
| if (inputs.size() == 4) { | if (inputs.size() == 4) { | ||||
| // input order: input, begins, ends, strides. | // input order: input, begins, ends, strides. | ||||
| auto begin_tensor = inputs.at(1); | auto begin_tensor = inputs.at(1); | ||||
| @@ -47,6 +47,7 @@ class StridedSlice : public PrimitiveC { | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | ||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | ||||
| bool CheckInputs(std::vector<lite::Tensor *> inputs_); | |||||
| int GetBeginMask() const; | int GetBeginMask() const; | ||||
| int GetEndMask() const; | int GetEndMask() const; | ||||
| int GetEllipsisMask() const; | int GetEllipsisMask() const; | ||||
| @@ -81,7 +81,6 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std:: | |||||
| STATUS TfliteModelParser::ConvertOps() { | STATUS TfliteModelParser::ConvertOps() { | ||||
| const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | const auto &tflite_subgraph = tflite_model_->subgraphs.front(); | ||||
| const auto &tflite_model_buffers = tflite_model_->buffers; | |||||
| NoSupportOp::GetInstance()->SetFmkType("TFLITE"); | NoSupportOp::GetInstance()->SetFmkType("TFLITE"); | ||||
| STATUS status = RET_OK; | STATUS status = RET_OK; | ||||
| int op_idx = 0; | int op_idx = 0; | ||||
| @@ -117,6 +116,9 @@ STATUS TfliteModelParser::ConvertOps() { | |||||
| std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))}; | std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))}; | ||||
| // parse inputs | // parse inputs | ||||
| for (auto input_idx : op->inputs) { | for (auto input_idx : op->inputs) { | ||||
| if (tflite_op_type == tflite::BuiltinOperator_FULLY_CONNECTED && input_idx == -1) { | |||||
| continue; | |||||
| } | |||||
| if (input_idx < 0) { | if (input_idx < 0) { | ||||
| input_idx += tflite_subgraph->tensors.size(); | input_idx += tflite_subgraph->tensors.size(); | ||||
| } | } | ||||
| @@ -126,18 +128,14 @@ STATUS TfliteModelParser::ConvertOps() { | |||||
| continue; | continue; | ||||
| } | } | ||||
| // const tensor | // const tensor | ||||
| if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) { | |||||
| auto parameter = func_graph_->add_parameter(); | |||||
| status = ConvertConstTensor(input_tensor.get(), parameter.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; | |||||
| return status; | |||||
| } | |||||
| op_inputs.emplace_back(parameter); | |||||
| nodes_.insert(std::pair(input_idx, parameter)); | |||||
| continue; | |||||
| auto parameter = func_graph_->add_parameter(); | |||||
| status = ConvertConstTensor(input_tensor.get(), parameter.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; | |||||
| return status; | |||||
| } | } | ||||
| MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor."; | |||||
| op_inputs.emplace_back(parameter); | |||||
| nodes_.insert(std::pair(input_idx, parameter)); | |||||
| } | } | ||||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | auto new_cnode = func_graph_->NewCNode(op_inputs); | ||||
| new_cnode->set_fullname_with_scope(op_name); | new_cnode->set_fullname_with_scope(op_name); | ||||
| @@ -268,6 +266,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { | |||||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | ||||
| make_tuple_inputs.emplace_back(make_tuple_prim); | make_tuple_inputs.emplace_back(make_tuple_prim); | ||||
| for (auto outputNode : tflite_subgraph->outputs) { | for (auto outputNode : tflite_subgraph->outputs) { | ||||
| outputNode = outputNode < 0 ? outputNode + tflite_subgraph->tensors.size() : outputNode; | |||||
| auto cnode = nodes_.at(outputNode); | auto cnode = nodes_.at(outputNode); | ||||
| if (nullptr == cnode) { | if (nullptr == cnode) { | ||||
| MS_LOG(ERROR) << "Can't find input node."; | MS_LOG(ERROR) << "Can't find input node."; | ||||
| @@ -296,9 +295,12 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { | |||||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| int outputNode = tflite_subgraph->outputs.front() < 0 | |||||
| ? static_cast<int>(tflite_subgraph->outputs.front() + tflite_subgraph->tensors.size()) | |||||
| : static_cast<int>(tflite_subgraph->outputs.front()); | |||||
| auto valueNode = NewValueNode(returnPrim); | auto valueNode = NewValueNode(returnPrim); | ||||
| std::vector<AnfNodePtr> op_inputs{valueNode}; | std::vector<AnfNodePtr> op_inputs{valueNode}; | ||||
| auto cnode = nodes_.at(tflite_subgraph->outputs.front()); | |||||
| auto cnode = nodes_.at(outputNode); | |||||
| if (nullptr == cnode) { | if (nullptr == cnode) { | ||||
| MS_LOG(ERROR) << "Can't find input node."; | MS_LOG(ERROR) << "Can't find input node."; | ||||
| return RET_NOT_FIND_OP; | return RET_NOT_FIND_OP; | ||||
| @@ -345,8 +347,8 @@ STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Para | |||||
| } | } | ||||
| std::memcpy(tensor_data, data.data(), size); | std::memcpy(tensor_data, data.data(), size); | ||||
| param_value->SetTensorData(tensor_data, size); | param_value->SetTensorData(tensor_data, size); | ||||
| parameter->set_default_param(param_value); | |||||
| } | } | ||||
| parameter->set_default_param(param_value); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -50,8 +50,15 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) { | |||||
| } | } | ||||
| auto data_node = depthwise_cnode->input(kConvInputIndex)->abstract(); | auto data_node = depthwise_cnode->input(kConvInputIndex)->abstract(); | ||||
| if (data_node == nullptr) { | |||||
| MS_LOG(ERROR) << "the node input is invalid."; | |||||
| return false; | |||||
| } | |||||
| auto data_shape = utils::cast<abstract::ShapePtr>(data_node->GetShapeTrack())->shape(); | auto data_shape = utils::cast<abstract::ShapePtr>(data_node->GetShapeTrack())->shape(); | ||||
| if (data_shape.empty()) { | |||||
| MS_LOG(DEBUG) << "the tensor's shape is dynamic."; | |||||
| return true; | |||||
| } | |||||
| auto conv_attr = std::make_unique<schema::Conv2DT>(); | auto conv_attr = std::make_unique<schema::Conv2DT>(); | ||||
| if (conv_attr == nullptr) { | if (conv_attr == nullptr) { | ||||
| MS_LOG(ERROR) << "conv_attr is null"; | MS_LOG(ERROR) << "conv_attr is null"; | ||||
| @@ -89,7 +89,7 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto ret = memcpy_s(tensor_data, new_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size()); | auto ret = memcpy_s(tensor_data, new_value->tensor_size(), param_value->tensor_addr(), param_value->tensor_size()); | ||||
| if (ret != EOK) { | |||||
| if (new_value->tensor_size() != 0 && ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy error: " << ret; | MS_LOG(ERROR) << "memcpy error: " << ret; | ||||
| delete[] tensor_data; | delete[] tensor_data; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -163,7 +163,7 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size()); | ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size()); | ||||
| if (ret != EOK) { | |||||
| if (tensor->Size() != 0 && ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy error: " << ret; | MS_LOG(ERROR) << "memcpy error: " << ret; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||