@@ -457,6 +457,7 @@ table Min {
 table Slice {
     format: Format = 0;
+    axes: [int];
     begin: [int];
     size: [int];
 }
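The new `axes` field records which input dimensions a slice touches, so ONNX slices that name only some axes can round-trip without padding. A purely illustrative sketch of how the three fields relate, using a plain struct in place of the generated schema object:

```cpp
#include <vector>

// Illustrative stand-in for the generated Slice attribute object.
struct SliceAttrSketch {
  std::vector<int> axes;   // which input dimensions are sliced
  std::vector<int> begin;  // start offset for each listed axis
  std::vector<int> size;   // extent for each listed axis; -1 means "to the end"
};

// An ONNX Slice(starts=[1], ends=[3], axes=[2]) maps to:
SliceAttrSketch onnx_slice{{2}, {1}, {2}};
// A TFLite slice on a rank-4 tensor always names every axis:
SliceAttrSketch tflite_slice{{0, 1, 2, 3}, {0, 0, 1, 0}, {-1, -1, 2, -1}};
```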
@@ -65,10 +65,6 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   }
   auto indices_shape = indices->shape();
   int indices_rank = indices_shape.size();
-  if (indices_rank < batch_dims + 1) {
-    MS_LOG(ERROR) << "input[1]'s rank is less than batchDim + 1";
-    return RET_ERROR;
-  }
   if (batch_dims != 0) {
     MS_LOG(ERROR) << "batchDims " << batch_dims << " != 0, which is not support";
     return RET_ERROR;
@@ -38,6 +38,7 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
   auto in_tensor = inputs_.front();
   auto out_tensor = outputs_.front();
   out_tensor->set_data_type(kNumberTypeInt32);
+  out_tensor->SetFormat(schema::Format_NHWC);
   if (!GetInferFlag()) {
     return RET_OK;
   }
@@ -29,6 +29,7 @@ constexpr int kSliceOutputNum = 1;
 int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; }
 std::vector<int> Slice::GetBegin() const { return this->primitive_->value.AsSlice()->begin; }
 std::vector<int> Slice::GetSize() const { return this->primitive_->value.AsSlice()->size; }
+std::vector<int> Slice::GetAxes() const { return this->primitive_->value.AsSlice()->axes; }
 void Slice::SetFormat(int format) { this->primitive_->value.AsSlice()->format = (schema::Format)format; }
 void Slice::SetBegin(const std::vector<int> &begin) { this->primitive_->value.AsSlice()->begin = begin; }
@@ -45,9 +46,14 @@ std::vector<int> Slice::GetSize() const {
   auto fb_vector = this->primitive_->value_as_Slice()->size();
   return std::vector<int>(fb_vector->begin(), fb_vector->end());
 }
+std::vector<int> Slice::GetAxes() const {
+  auto fb_vector = this->primitive_->value_as_Slice()->axes();
+  return std::vector<int>(fb_vector->begin(), fb_vector->end());
+}
 #endif
+std::vector<int> Slice::GetPostProcessBegin() const { return this->begin; }
+std::vector<int> Slice::GetPostProcessSize() const { return this->size; }
 int Slice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
   MS_ASSERT(this->primitive_ != nullptr);
   if (inputs.size() != kSliceInputNum || outputs.size() != kSliceOutputNum) {
@@ -61,30 +67,37 @@ int Slice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<li
     return RET_OK;
   }
   auto input_shape = input->shape();
-  std::vector<int32_t> slice_begin(GetBegin().begin(), GetBegin().end());
-  std::vector<int32_t> slice_size(GetSize().begin(), GetSize().end());
+  std::vector<int32_t> slice_begin(GetBegin());
+  std::vector<int32_t> slice_size(GetSize());
+  std::vector<int32_t> slice_axes(GetAxes());
   std::vector<int32_t> output_shape(input_shape.size());
+  begin.assign(input_shape.size(), 0);
+  size.assign(input_shape.size(), -1);
+  for (size_t i = 0; i < slice_axes.size(); ++i) {
+    begin[slice_axes[i]] = slice_begin[i];
+    size[slice_axes[i]] = slice_size[i];
+  }
   for (size_t i = 0; i < input_shape.size(); ++i) {
-    if (slice_size[i] < 0 && slice_size[i] != -1) {
-      MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << slice_size[i];
+    if (size[i] < 0 && size[i] != -1) {
+      MS_LOG(ERROR) << "Invalid size input!size[" << i << "]=" << size[i];
       return RET_PARAM_INVALID;
     }
-    if (slice_begin[i] < 0) {
-      MS_LOG(ERROR) << "Invalid begin input " << slice_begin[i] << " which should be >= 0";
+    if (begin[i] < 0) {
+      MS_LOG(ERROR) << "Invalid begin input " << begin[i] << " which should be >= 0";
       return RET_PARAM_INVALID;
     }
-    if (input_shape[i] <= slice_begin[i]) {
-      MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << slice_begin[i]
+    if (input_shape[i] <= begin[i]) {
+      MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << begin[i]
                     << " which should be <= " << input_shape[i];
       return RET_PARAM_INVALID;
     }
-    if (slice_size[i] > (input_shape[i] - slice_begin[i])) {
-      MS_LOG(ERROR) << "Invalid size input " << slice_size[i]
-                    << " which should be <= " << input_shape[i] - slice_begin[i];
+    if (size[i] > (input_shape[i] - begin[i])) {
+      MS_LOG(ERROR) << "Invalid size input " << size[i]
+                    << " which should be <= " << input_shape[i] - begin[i];
       return RET_PARAM_INVALID;
     }
-    output_shape[i] = slice_size[i] < 0 ? input_shape[i] - slice_begin[i] : slice_size[i];
+    output_shape[i] = size[i] < 0 ? input_shape[i] - begin[i] : size[i];
   }
   outputs[0]->set_shape(output_shape);
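The core of this change is the normalization step: per-axis begin/size pairs are scattered into full-rank `begin`/`size` members, with untouched axes defaulting to begin 0 and size -1 ("take everything"). A minimal standalone sketch of just that step, assuming a hypothetical free function `NormalizeSlice` (not part of the runtime API):

```cpp
#include <cassert>
#include <vector>

// Hypothetical restatement of the normalization in Slice::InferShape:
// scatter per-axis begin/size into full-rank vectors; untouched axes
// default to begin 0 and size -1 ("to the end").
void NormalizeSlice(const std::vector<int> &axes, const std::vector<int> &slice_begin,
                    const std::vector<int> &slice_size, size_t rank,
                    std::vector<int> *begin, std::vector<int> *size) {
  begin->assign(rank, 0);
  size->assign(rank, -1);
  for (size_t i = 0; i < axes.size(); ++i) {
    (*begin)[axes[i]] = slice_begin[i];
    (*size)[axes[i]] = slice_size[i];
  }
}

int main() {
  // ONNX-style slice on a rank-4 tensor that only touches axis 2.
  std::vector<int> begin, size;
  NormalizeSlice({2}, {1}, {2}, 4, &begin, &size);
  assert((begin == std::vector<int>{0, 0, 1, 0}));
  assert((size == std::vector<int>{-1, -1, 2, -1}));
  return 0;
}
```

After normalization, every validation and the output-shape computation can index `begin`/`size` directly by input dimension, regardless of whether the model came from TFLite (all axes listed) or ONNX (a subset).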
@@ -41,6 +41,14 @@ class Slice : public PrimitiveC {
   int GetFormat() const;
   std::vector<int> GetBegin() const;
   std::vector<int> GetSize() const;
+  std::vector<int> GetAxes() const;
+  // TFLite and ONNX describe a slice differently, so InferShape constructs normalized
+  // begin and size vectors; when running the graph, obtain them via the two functions below.
+  std::vector<int> GetPostProcessBegin() const;
+  std::vector<int> GetPostProcessSize() const;
+ protected:
+  std::vector<int> begin = {0};
+  std::vector<int> size = {-1};
 };
 } // namespace lite
 } // namespace mindspore
@@ -1010,8 +1010,8 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive
   memset(slice_param, 0, sizeof(SliceParameter));
   auto param = reinterpret_cast<mindspore::lite::Slice *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
   slice_param->op_parameter_.type_ = primitive->Type();
-  auto param_begin = param->GetBegin();
-  auto param_size = param->GetSize();
+  auto param_begin = param->GetPostProcessBegin();
+  auto param_size = param->GetPostProcessSize();
   if (param_begin.size() != param_size.size()) {
     free(slice_param);
     return nullptr;
@@ -20,6 +20,7 @@
 #include "nnacl/fp32/slice.h"
 #include "include/errorcode.h"
 #include "src/runtime/runtime_api.h"
+#include "src/ops/slice.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
@@ -40,7 +41,15 @@ int SliceLaunch(int thread_id, LiteParallelGroupEnv *penv, void *cdata) {
 } // namespace
 int SliceCPUKernel::ReSize() {
-  auto *param = reinterpret_cast<SliceParameter *>(op_parameter_);
+  auto primitive_slice = reinterpret_cast<const mindspore::lite::Slice *>(primitive_);
+  auto begin = primitive_slice->GetPostProcessBegin();
+  auto size = primitive_slice->GetPostProcessSize();
+  auto param = reinterpret_cast<SliceParameter *>(op_parameter_);
+  param->param_length_ = in_tensors_[0]->shape().size();
+  for (int i = 0; i < param->param_length_; ++i) {
+    param->begin_[i] = begin[i];
+    param->size_[i] = size[i];
+  }
   auto input_shape = in_tensors_[0]->shape();
   if (static_cast<int>(input_shape.size()) != param->param_length_) {
     MS_LOG(ERROR) << "Input begin's lenth " << param->param_length_ << "is not equal to input shape size "
@@ -24,8 +24,8 @@ namespace lite {
 namespace converter {
 Flags::Flags() {
   AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", "");
-  AddFlag(&Flags::modelFile, "modelFile", "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir",
-          "");
+  AddFlag(&Flags::modelFile, "modelFile",
+          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
@@ -41,6 +41,10 @@ Flags::Flags() {
 }
 int Flags::Init(int argc, const char **argv) {
+  if (argc == 1) {
+    std::cout << this->Usage() << std::endl;
+    return RET_SUCCESS_EXIT;
+  }
   Option<std::string> err = this->ParseFlags(argc, argv);
   if (err.IsSome()) {
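With the check above, running the converter with no arguments now prints the usage text and exits cleanly instead of failing flag parsing. As a purely illustrative invocation of the newly documented ONNX path (the binary name and file paths are assumptions, not taken from this diff): `./converter_lite --fmk=ONNX --modelFile=model.onnx --outputFile=model`.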
@@ -121,7 +121,8 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt
     return RET_ERROR;
   }
   // NHWC
-  scaleParam->axis = -1;
+  int shape_size = graph->allTensors.at(addBiasIndex)->dims.size();
+  scaleParam->axis = 0 - shape_size;
   mulNode->primitive->value.value = scaleParam.release();
   mulNode->inputIndex.push_back(addBiasIndex);
   if (addNode->primitive->value.AsAdd()->activationType != ActivationType_NO_ACTIVATION) {
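Setting `axis = 0 - shape_size` anchors the scale at the first dimension of the bias tensor whatever its rank, since a negative axis counts from the end of the input's shape. A small sketch of the arithmetic (`ResolveAxis` is a hypothetical helper, not from the codebase):

```cpp
#include <cassert>

// Hypothetical helper mirroring how a Scale kernel would interpret a
// negative axis: count backwards from the end of the input's shape.
int ResolveAxis(int axis, int input_rank) { return axis < 0 ? axis + input_rank : axis; }

int main() {
  // Bias of rank 1 (shape {C}): axis = 0 - 1 = -1, which resolves to the
  // channel dimension (3) of a rank-4 NHWC input.
  assert(ResolveAxis(-1, 4) == 3);
  // Bias of rank 4 (shape {1, H, W, C}): axis = 0 - 4 = -4, resolving to 0,
  // so the scale/bias spans all four dimensions.
  assert(ResolveAxis(-4, 4) == 0);
  return 0;
}
```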
@@ -38,22 +38,38 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::No
     return RET_NULL_PTR;
   }
+  std::vector<int> axes;
+  std::vector<int> starts;
+  std::vector<int> ends;
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "starts") {
-      const int size = onnx_node_attr.ints_size();
-      MS_LOG(INFO) << "SLICE starts size " << size;
-      for (int i = 0; i < size; ++i) {
-        attr->begin.emplace_back(static_cast<int32_t>(onnx_node_attr.ints(i)));
+      const int num = onnx_node_attr.ints_size();
+      starts.clear();
+      for (int i = 0; i < num; ++i) {
+        starts.push_back(static_cast<int>(onnx_node_attr.ints()[i]));
+      }
+    } else if (attribute_name == "axes") {
+      const int num = onnx_node_attr.ints_size();
+      axes.clear();
+      for (int i = 0; i < num; ++i) {
+        axes.push_back(static_cast<int>(onnx_node_attr.ints()[i]));
       }
     } else if (attribute_name == "ends") {
-      const int size = onnx_node_attr.ints_size();
-      for (int i = 0; i < size; ++i) {
-        attr->size.emplace_back(static_cast<int32_t>(onnx_node_attr.ints(i)));
+      const int num = onnx_node_attr.ints_size();
+      ends.clear();
+      for (int i = 0; i < num; ++i) {
+        ends.push_back(static_cast<int>(onnx_node_attr.ints()[i]));
       }
     }
   }
+  std::vector<int> sizes(starts.size(), -1);
+  for (size_t i = 0; i < starts.size(); ++i) {
+    sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]);
+  }
+  attr->axes = axes;
+  attr->begin = starts;
+  attr->size = sizes;
   op->primitive->value.type = schema::PrimitiveType_Slice;
   op->primitive->value.value = attr.release();
   return RET_OK;
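A worked example of the starts/ends-to-size conversion above, as a standalone sketch (`ComputeSizes` is an illustrative name, not part of the codebase):

```cpp
#include <cassert>
#include <vector>

// Restatement of the parser's ends -> size conversion: negative ends are
// kept verbatim; non-negative ends become end - start.
std::vector<int> ComputeSizes(const std::vector<int> &starts, const std::vector<int> &ends) {
  std::vector<int> sizes(starts.size(), -1);
  for (size_t i = 0; i < starts.size(); ++i) {
    sizes[i] = (ends[i] < 0 ? ends[i] : ends[i] - starts[i]);
  }
  return sizes;
}

int main() {
  // ONNX Slice with starts=[1, 0] and ends=[4, -1]:
  // axis 0 keeps elements [1, 4) -> size 3; axis 1 runs to the end -> size -1.
  assert((ComputeSizes({1, 0}, {4, -1}) == std::vector<int>{3, -1}));
  return 0;
}
```

Note that only an end of -1 survives the later validation in `Slice::InferShape` (a size that is negative but not -1 is rejected), so other negative ends would be flagged as invalid sizes.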
@@ -55,7 +55,12 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
     MS_LOG(ERROR) << "get slice -> size failed";
     return RET_ERROR;
   }
+  std::vector<int> axes;
+  axes.clear();
+  for (size_t i = 0; i < attr->begin.size(); ++i) {
+    axes.push_back(i);
+  }
+  attr->axes = axes;
   op->primitive->value.type = schema::PrimitiveType_Slice;
   op->primitive->value.value = attr.release();
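TFLite's Slice always carries begin/size for every dimension, so the matching axes vector is just the identity sequence 0..rank-1. A sketch of the same construction with `std::iota` (an equivalent, more idiomatic form of the loop above; `IdentityAxes` is an illustrative name):

```cpp
#include <numeric>
#include <vector>

// Build the identity axes vector used for full-rank TFLite slices.
std::vector<int> IdentityAxes(size_t rank) {
  std::vector<int> axes(rank);
  std::iota(axes.begin(), axes.end(), 0);
  return axes;  // e.g. rank 4 -> {0, 1, 2, 3}
}
```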
@@ -72,8 +72,7 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
 }
 const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
   auto parameter = func_graph->add_parameter();
-  std::vector<int> shape;
-  std::copy(tensor->shape().begin(), tensor->shape().end(), std::back_inserter(shape));
+  std::vector<int> shape(tensor->shape());
   auto type_id = static_cast<TypeId>(tensor->data_type());
   auto type_ptr = TypeIdToType(type_id);
   auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
@@ -160,6 +159,15 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
     MS_LOG(ERROR) << "lite_primitive is nullptr";
     return nullptr;
   }
+  // The input tensors' format needs to be set to NHWC according to fmkType.
+  // For the time being only 0/1/2/3-D tensors are handled this way; real
+  // transposition of 4-D tensor data should be added in the future.
+  for (size_t j = 0; j < input_tensors.size(); ++j) {
+    input_tensors[j]->SetFormat(schema::Format_NHWC);
+    if (input_tensors[j]->shape().size() == 4) {
+      MS_LOG(WARNING) << "init input_tensor format to nhwc";
+    }
+  }
   lite_primitive->InferShape(input_tensors, output_tensors);
   auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive.get());
   if (lite_kernel == nullptr) {