Browse Source

fix topk and trans_format_insert_pass

tags/v1.2.0-rc1
gongdaguo 5 years ago
parent
commit
3f42e6c659
2 changed files with 9 additions and 12 deletions
  1. +5
    -12
      mindspore/lite/nnacl/fp32/topk_fp32.c
  2. +4
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc

+ 5
- 12
mindspore/lite/nnacl/fp32/topk_fp32.c View File

@@ -31,16 +31,10 @@ int DescendCmp(const void *a, const void *b) {
}

int AscendCmp(const void *a, const void *b) {
float sub = ((const TopkNode *)a)->element - ((const TopkNode *)b)->element;
if (sub > 0) {
return 1;
} else if (sub < 0) {
return -1;
}
if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) {
return -1;
} else {
return 1;
} else {
return -1;
}
}

@@ -58,10 +52,9 @@ void Topk(float *input_data, float *output_data, int32_t *output_index, TopkPara
top_map[j].element = *(cur_input_data + j);
top_map[j].index = j;
}
if (parameter->sorted_) {
qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp);
} else {
qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmp);
qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmp);
if (!parameter->sorted_) {
qsort(top_map, k, sizeof(top_map[0]), AscendCmp);
}
for (int m = 0; m < k; m++) {
cur_output_data[m] = top_map[m].element;


+ 4
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc View File

@@ -206,6 +206,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
continue;
}
#endif
auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]);
if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) {
continue;
}
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";


Loading…
Cancel
Save