diff --git a/mindspore/lite/nnacl/split.c b/mindspore/lite/nnacl/split.c index 03e7f1eddf..dfa7ad8dd0 100644 --- a/mindspore/lite/nnacl/split.c +++ b/mindspore/lite/nnacl/split.c @@ -50,6 +50,14 @@ int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset split_which = i % num_split; split_times = i / num_split; int split_size = split_sizes[split_which]; + // support split size is -1 in the end. + if (split_size == -1) { + int split_dim_i = input_shape[split_dim]; + for (int j = 0; j < num_split - 1; ++j) { + split_dim_i -= split_sizes[j]; + } + split_size = split_dim_i; + } float *dst = out_data[split_which] + split_times * in_stride * split_size; (void)memcpy(dst, src, split_size * in_stride_bytes); src += split_size * in_stride; diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index a7bde44996..4798f54d88 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ -88,7 +88,7 @@ int Split::InferShape(std::vector inputs_, std::vectorshape().size() - 1 : GetSplitDim(); std::vector input_shape = input->shape(); std::vector size_split; for (size_t i = 0; i < GetSizeSplits().size(); ++i) { @@ -97,7 +97,15 @@ int Split::InferShape(std::vector inputs_, std::vector output_shape; output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end()); - auto split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i]; + int split_dim_i = input_shape[split_dim]; + // support split size is -1 in the end. + if (i == number_split - 1 && size_split[i] == -1) { + for (size_t j = 0; j < size_split.size() - 1; ++j) { + split_dim_i -= size_split[j]; + } + } else { + split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i]; + } output_shape[split_dim] = split_dim_i; outputs_[i]->set_shape(output_shape); outputs_[i]->set_data_type(input->data_type()); diff --git a/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt index 2faf790b21..d10837bdee 100644 --- a/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt +++ b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt @@ -29,4 +29,6 @@ add_library(caffe_parser_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/caffe_permute_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_tile_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_tanh_parser.cc - ${CMAKE_CURRENT_SOURCE_DIR}/caffe_exp_parser.cc) + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_exp_parser.cc + ${CMAKE_CURRENT_SOURCE_DIR}/caffe_slice_parser.cc + ) diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.cc new file mode 100644 index 0000000000..44708fd5a3 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/caffe/caffe_slice_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS CaffeSliceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + MS_LOG(DEBUG) << "parse CaffeSliceParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + const caffe::SliceParameter &slice_param = proto.slice_param(); + + if (!slice_param.slice_point().empty()) { + attr->numberSplit = slice_param.slice_point_size() + 1; + std::vector size_splits; + for (int i = 0; i < slice_param.slice_point_size(); ++i) { + if (i == 0) { + size_splits.push_back(slice_param.slice_point(i)); + } else { + size_splits.push_back(slice_param.slice_point(i) - slice_param.slice_point(i - 1)); + } + } + size_splits.push_back(-1); + attr->sizeSplits = size_splits; + } + + // The axis along which to slice -- may be negative to index from the end (e.g., -1 for the last axis). + if (slice_param.has_axis()) { + attr->splitDim = slice_param.axis(); + } else if (slice_param.has_slice_dim()) { + attr->splitDim = slice_param.slice_dim(); + } + op->name = proto.name(); + op->primitive->value.type = schema::PrimitiveType_Split; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeSliceParser("Slice", new CaffeSliceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.h new file mode 100644 index 0000000000..bb046a4c91 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_slice_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MMINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_SLICE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_SLICE_PARSER_H_ + +#include +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser.h" +#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeSliceParser : public CaffeNodeParser { + public: + CaffeSliceParser() : CaffeNodeParser("slice") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_SLICE_PARSER_H_