From: @cjh9368
Reviewed-by: @zhanghaibo5
Signed-off-by: tags/v1.2.0-rc1
@@ -987,6 +987,22 @@ int ElementMaximumInt(const int *in0, const int *in1, int *out, int size) {
   return NNACL_OK;
 }

+int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  for (; index <= element_size - 4; index += C4NUM) {
+    int32x4_t vin0 = vld1q_s32(input0 + index);
+    int32x4_t vin1 = vld1q_s32(input1 + index);
+    int32x4_t vout = vminq_s32(vin0, vin1);
+    vst1q_s32(output + index, vout);
+  }
+#endif
+  for (; index < element_size; index++) {
+    output[index] = input0[index] > input1[index] ? input1[index] : input0[index];
+  }
+  return NNACL_OK;
+}
+
 int BroadcastMaximum(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
                      ArithmeticParameter *param) {
   TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param);
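For readers unfamiliar with the NNACL kernel style, the new function follows the standard NEON pattern: a vector main loop that consumes four int32 lanes per iteration (C4NUM), then a scalar tail for the remaining size % 4 elements. Below is a minimal, NEON-free sketch of the same loop structure; ElementMinimumIntRef and the 4-lane stride constant are illustrative stand-ins, not part of the patch.

#include <algorithm>
#include <cstdio>

// Hypothetical stand-in mirroring ElementMinimumInt's loop structure:
// a strided "vector" main loop plus a scalar tail for leftovers.
int ElementMinimumIntRef(const int *in0, const int *in1, int *out, int size) {
  const int kStride = 4;  // assumption: mirrors C4NUM
  int i = 0;
  // Main loop: on NEON this body would be vld1q_s32/vminq_s32/vst1q_s32.
  for (; i <= size - kStride; i += kStride) {
    for (int lane = 0; lane < kStride; ++lane) {
      out[i + lane] = std::min(in0[i + lane], in1[i + lane]);
    }
  }
  // Scalar tail handles the size % 4 leftover elements.
  for (; i < size; ++i) {
    out[i] = std::min(in0[i], in1[i]);
  }
  return 0;  // NNACL_OK
}

int main() {
  const int a[6] = {3, 1, 4, 1, 5, 9};
  const int b[6] = {2, 7, 1, 8, 2, 8};
  int out[6];
  ElementMinimumIntRef(a, b, out, 6);
  for (int v : out) printf("%d ", v);  // prints: 2 1 1 1 2 8
  printf("\n");
  return 0;
}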
@@ -95,6 +95,7 @@ int ElementSquaredDifference(const float *in0, const float *in1, float *out, int
 int ElementMaximum(const float *in0, const float *in1, float *out, int size);
 int ElementMinimum(const float *in0, const float *in1, float *out, int size);
 int ElementMaximumInt(const int *in0, const int *in1, int *out, int size);
+int ElementMinimumInt(const int *input0, const int *input1, int *output, const int element_size);
 int BroadcastMaximum(const float *in0, const float *in1, float *tile_input0, float *tile_input1, float *out, int size,
                      ArithmeticParameter *param);
@@ -165,6 +165,7 @@ int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
   output->set_data_type(input0->tensors_data_type());
   output_shape_.insert(output_shape_.begin(), input0->ElementsNum());
   output->set_shape(output_shape_);
+  output->set_format(input0->format());
   return RET_OK;
 }
@@ -169,6 +169,7 @@ void ArithmeticCPUKernel::InitRunFunction() {
       break;
     case PrimitiveType_Minimum:
       arithmetic_run_ = ElementMinimum;
+      arithmetic_run_int_ = ElementMinimumInt;
       break;
     case PrimitiveType_FloorDiv:
       arithmetic_run_ = ElementFloorDiv;
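With this registration, the Minimum case carries both a float and an int run function. A toy model of that dual-typed dispatch is sketched below; only the registration step mirrors the diff, while the type names and the selection by the caller are assumptions.

#include <algorithm>
#include <cstdio>

// Hypothetical sketch of the dual-typed dispatch behind InitRunFunction:
// one function pointer per element type, selected by the input dtype.
typedef int (*RunFp32)(const float *, const float *, float *, int);
typedef int (*RunInt)(const int *, const int *, int *, int);

static int MinFp32(const float *a, const float *b, float *o, int n) {
  for (int i = 0; i < n; ++i) o[i] = std::min(a[i], b[i]);
  return 0;
}
static int MinInt(const int *a, const int *b, int *o, int n) {
  for (int i = 0; i < n; ++i) o[i] = std::min(a[i], b[i]);
  return 0;
}

struct Dispatch {
  RunFp32 run_fp32 = nullptr;
  RunInt run_int = nullptr;  // analogous to arithmetic_run_int_
};

int main() {
  Dispatch d;
  d.run_fp32 = MinFp32;  // case PrimitiveType_Minimum in the diff
  d.run_int = MinInt;    // the newly registered int path
  int a[2] = {3, 1}, b[2] = {2, 4}, o[2];
  d.run_int(a, b, o, 2);
  printf("%d %d\n", o[0], o[1]);  // prints: 2 1
  return 0;
}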
@@ -30,7 +30,17 @@ namespace mindspore::kernel {
 int TensorListReserveCPUKernel::Init() { return RET_OK; }

 int TensorListReserveCPUKernel::Run() {
+  auto input0 = in_tensors_.at(0);
+  auto input1 = in_tensors_.at(1);
+  int num_elements = reinterpret_cast<int *>(input1->data_c())[0];
   auto output = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
+  if (output->tensors().size() < (uint32_t)num_elements) {
+    auto ele_shape_ptr = reinterpret_cast<int *>(input0->data_c());
+    std::vector<std::vector<int> > tmp_shape(num_elements, std::vector<int>());
+    output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum()));
+    output->set_shape(std::vector<int>(1, num_elements));
+    output->MallocTensorListData(kTypeUnknown, tmp_shape);
+  }
   output->set_tensors_data_type(element_dtype_);
   return RET_OK;
 }
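The fix makes Reserve actually allocate list slots when the requested count exceeds the current list size, taking the element shape from the first input and the element count from the second. A rough model of that grow-if-needed behavior using standard containers (TensorListModel and Reserve are illustrative stand-ins, not the lite::TensorList API):

#include <vector>

// Illustrative stand-in for lite::TensorList: a shape per element slot.
struct TensorListModel {
  std::vector<int> element_shape;         // shared element shape
  std::vector<std::vector<int>> tensors;  // one (possibly empty) shape per slot
};

// Mirrors the patched Run(): grow only if the list is smaller than requested.
void Reserve(TensorListModel *list, const std::vector<int> &ele_shape, int num_elements) {
  if (list->tensors.size() < static_cast<size_t>(num_elements)) {
    list->element_shape = ele_shape;
    // kTypeUnknown analogue: slots exist but stay shapeless until SetItem fills them.
    list->tensors.assign(num_elements, std::vector<int>());
  }
}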
@@ -83,7 +83,14 @@ int TensorListSetItemCPUKernel::Run() {
     auto src = input0_->GetTensor(i);
     auto dst = output0_->GetTensor(i);
     MS_ASSERT(src != nullptr);
-    MS_ASSERT(dst != nullptr);
+    // merge moves data and may delete tensors, so dst can be gone; recreate it from src
+    if (dst == nullptr) {
+      dst = lite::Tensor::CopyTensor(*src, src->data_c() != nullptr);
+      auto &tensors = output0_->tensors();
+      tensors.emplace_back(dst);
+      continue;
+    }
     if (src->data_type() != kTypeUnknown) {
       if (src->Size() != dst->Size()) {
         MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
@@ -288,7 +288,6 @@ int Tensor::set_root_tensor(Tensor *tensor) {
   this->shape_ = this->root_tensor_->shape_;
   this->format_ = this->root_tensor_->format_;
   this->data_type_ = this->root_tensor_->data_type_;
-  this->allocator_ = this->root_tensor_->allocator_;
   this->category_ = this->root_tensor_->category_;
   this->quant_params_ = this->root_tensor_->quant_params_;
   this->quant_clusters_ = this->root_tensor_->quant_clusters_;
@@ -264,8 +264,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     // init old node indecies
     auto old_nodes = GetGraphNodes();
     Optimizer selectOptimizer;
-    selectOptimizer.AddPass(new (std::nothrow) SelectPass());
-    selectOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    selectOptimizer.AddPass(new (std::nothrow) SelectPass(graphDefT));
     status = selectOptimizer.Run(graphDefT);
     if (status != RET_OK && status != RET_NO_CHANGE) {
       MS_LOG(ERROR) << "Run switch graphPasses Failed";
@@ -16,6 +16,7 @@
 #include <vector>
 #include <map>
+#include <algorithm>
 #include "tools/converter/legacy_optimizer/graph/select_pass.h"
 #include "src/common/log_adapter.h"
 #include "include/errorcode.h"
@@ -40,6 +41,52 @@ STATUS SelectPass::Run(mindspore::schema::MetaGraphT *graph) {
       MS_LOG(ERROR) << "node: " << node->name << "'s select pass failed: " << ret;
       return ret;
     }
+    select_indices_.emplace_back(i);
+  }
+
+  int ret = RemoveSelectNodes();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "remove select nodes failed";
+    return ret;
+  }
+  return RET_OK;
+}
+
+STATUS SelectPass::RemoveSelectNodes() {
+  std::sort(select_indices_.begin(), select_indices_.end(), std::greater<int>());
+  for (auto select_indice : select_indices_) {
+    auto &node = graph_->nodes.at(select_indice);
+    if (node->primitive->value.type != PrimitiveType_Select) {
+      MS_LOG(ERROR) << "node " << node->name << " is not a select node";
+      return RET_ERROR;
+    }
+    int subgraph_idx = -1;
+    for (size_t i = 0; i < graph_->subGraph.size(); i++) {
+      if (IsContain(graph_->subGraph.at(i)->nodeIndices, select_indice)) {
+        subgraph_idx = i;
+        break;
+      }
+    }
+    if (subgraph_idx == -1) {
+      MS_LOG(ERROR) << "select node " << node->name << " does not belong to any subgraph";
+      return RET_ERROR;
+    }
+    graph_->nodes.erase(graph_->nodes.begin() + select_indice);
+    std::vector<uint32_t> new_node_indices;
+    std::copy_if(graph_->subGraph.at(subgraph_idx)->nodeIndices.begin(),
+                 graph_->subGraph.at(subgraph_idx)->nodeIndices.end(),
+                 std::inserter(new_node_indices, new_node_indices.begin()),
+                 [&select_indice](int indice) { return (uint32_t)indice != select_indice; });
+    graph_->subGraph.at(subgraph_idx)->nodeIndices = new_node_indices;
+    for (auto &subgraph : graph_->subGraph) {
+      std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
+                     [&select_indice](uint32_t idx) {
+                       if (idx > select_indice) {
+                         return --idx;
+                       }
+                       return idx;
+                     });
+    }
   }
   return RET_OK;
 }
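Two invariants make RemoveSelectNodes safe: indices are erased in descending order so earlier erasures cannot shift the positions still pending, and every subgraph index greater than an erased position is decremented afterwards. A self-contained sketch of that renumbering discipline, using plain vectors in place of MetaGraphT:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Erase nodes at the given positions and keep index lists consistent,
// mirroring RemoveSelectNodes: delete from highest index to lowest,
// drop the erased index from its subgraph, shift larger indices down.
void EraseNodes(std::vector<std::string> *nodes, std::vector<std::vector<uint32_t>> *subgraphs,
                std::vector<uint32_t> indices) {
  std::sort(indices.begin(), indices.end(), std::greater<uint32_t>());
  for (uint32_t idx : indices) {
    nodes->erase(nodes->begin() + idx);
    for (auto &sg : *subgraphs) {
      sg.erase(std::remove(sg.begin(), sg.end(), idx), sg.end());
      for (auto &i : sg) {
        if (i > idx) --i;  // renumber nodes that shifted down
      }
    }
  }
}

int main() {
  std::vector<std::string> nodes = {"a", "select0", "b", "select1"};
  std::vector<std::vector<uint32_t>> subgraphs = {{0, 1}, {2, 3}};
  EraseNodes(&nodes, &subgraphs, {1, 3});
  // nodes == {"a", "b"}, subgraphs == {{0}, {1}}
  return 0;
}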
@@ -20,6 +20,7 @@
 #include <string>
 #include <utility>
 #include <vector>
+#include <functional>
 #include "tools/common/graph_util.h"
 #include "tools/converter/optimizer.h"
@@ -28,9 +29,14 @@ namespace mindspore {
 namespace lite {
 class SelectPass : public GraphPass {
  public:
-  SelectPass() = default;
+  explicit SelectPass(schema::MetaGraphT *graph) : graph_(graph) {}
   ~SelectPass() override = default;
   STATUS Run(schema::MetaGraphT *graph) override;
+  STATUS RemoveSelectNodes();
+
+ private:
+  std::vector<uint32_t> select_indices_;
+  schema::MetaGraphT *graph_ = nullptr;
 };

 class SingleSelectPass {
@@ -251,9 +251,34 @@ STATUS SingleSwitchPass::InsertMerge() {
   second_partial_node_->inputIndex.assign(switch_node_->outputIndex.begin(),
                                           switch_node_->outputIndex.begin() + switch_node_->outputIndex.size() / 2);

+  // skip tensors that are not inputs of any node, to avoid connecting the body partial to the merge input cnode
+  std::vector<uint32_t> skip_input_tensors;
+  for (auto input : const_input) {
+    auto real_input = graph_->subGraph.at(second_subgraph_index_)->inputIndices.at(input);
+    bool skip = true;
+    for (auto &node : second_graph_nodes_) {
+      if (IsContain(node->inputIndex, real_input)) {
+        skip = false;
+        break;
+      }
+    }
+    if (skip) {
+      auto &skip_tensor = graph_->allTensors.at(real_input);
+      int partial_idx = GetSubgraphInputTensorIndex(graph_->subGraph.at(second_subgraph_index_), skip_tensor);
+      skip_input_tensors.emplace_back(partial_idx);
+    }
+  }
+
   // concat body output to merge input
-  second_partial_node_->outputIndex.assign(merge_node->inputIndex.begin() + merge_node->inputIndex.size() / 2,
-                                           merge_node->inputIndex.end());
+  second_partial_node_->outputIndex.clear();
+  for (uint32_t merge_right_input = 0; merge_right_input < merge_node->inputIndex.size() / 2; merge_right_input++) {
+    if (!IsContain(skip_input_tensors, merge_right_input)) {
+      second_partial_node_->outputIndex.emplace_back(
+        merge_node->inputIndex.at(merge_node->inputIndex.size() / 2 + merge_right_input));
+    } else {
+      second_partial_node_->outputIndex.emplace_back(UINT32_MAX);
+    }
+  }

   graph_->nodes.push_back(std::move(merge_node));
@@ -544,6 +569,13 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
                  [](std::pair<int, int> iter) { return iter.second; });
   subgraph_outputs.assign(new_subgraph_outputs.begin(), new_subgraph_outputs.end());

+  // filter out -1 (UINT32_MAX) placeholder output indices
+  std::vector<uint32_t> new_partial_outputs;
+  std::copy_if(partial_outputs.begin(), partial_outputs.end(),
+               std::inserter(new_partial_outputs, new_partial_outputs.begin()),
+               [](uint32_t output) { return output != UINT32_MAX; });
+  partial_node->outputIndex = new_partial_outputs;
+
   return RET_OK;
 }
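The last two hunks cooperate: InsertMerge emits UINT32_MAX as a placeholder output index for skipped tensors, and UpdateSubgraphOutput strips those placeholders afterwards. A minimal illustration of that emit-then-filter handshake (the container shapes are assumptions, not the real MetaGraphT types):

#include <algorithm>
#include <cstdint>
#include <vector>

// Phase 1 (InsertMerge analogue): emit UINT32_MAX for positions that must
// not connect the body partial to the merge input.
std::vector<uint32_t> EmitOutputs(const std::vector<uint32_t> &merge_right,
                                  const std::vector<uint32_t> &skip_positions) {
  std::vector<uint32_t> outputs;
  for (uint32_t pos = 0; pos < merge_right.size(); ++pos) {
    bool skip = std::find(skip_positions.begin(), skip_positions.end(), pos) != skip_positions.end();
    outputs.push_back(skip ? UINT32_MAX : merge_right[pos]);
  }
  return outputs;
}

// Phase 2 (UpdateSubgraphOutput analogue): drop the placeholders.
void FilterPlaceholders(std::vector<uint32_t> *outputs) {
  outputs->erase(std::remove(outputs->begin(), outputs->end(), UINT32_MAX), outputs->end());
}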