|
|
@@ -19,12 +19,38 @@ |
|
|
#include "tools/optimizer/common/gllo_utils.h" |
|
|
#include "tools/optimizer/common/gllo_utils.h" |
|
|
#include "schema/inner/model_generated.h" |
|
|
#include "schema/inner/model_generated.h" |
|
|
#include "tools/converter/quantizer/quant_cast.h" |
|
|
#include "tools/converter/quantizer/quant_cast.h" |
|
|
|
|
|
#include "src/common/utils.h" |
|
|
|
|
|
|
|
|
using mindspore::lite::PrimitiveC; |
|
|
using mindspore::lite::PrimitiveC; |
|
|
namespace mindspore::opt { |
|
|
namespace mindspore::opt { |
|
|
namespace { |
|
|
namespace { |
|
|
constexpr size_t split_inputs_size = 3; |
|
|
constexpr size_t split_inputs_size = 3; |
|
|
|
|
|
const std::vector<schema::PrimitiveType> single_input_ops = { |
|
|
|
|
|
schema::PrimitiveType_Reduce, schema::PrimitiveType_ArgMin, schema::PrimitiveType_ArgMax, |
|
|
|
|
|
schema::PrimitiveType_SpaceToBatch, schema::PrimitiveType_BatchToSpace, schema::PrimitiveType_SpaceToBatchND, |
|
|
|
|
|
schema::PrimitiveType_BatchToSpaceND, schema::PrimitiveType_SpaceToDepth}; |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
|
|
|
STATUS ReorderCnodeInputs(CNode *cnode, const std::vector<size_t> &perm) { |
|
|
|
|
|
// add primitive first |
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs = {cnode->input(0)}; |
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); |
|
|
|
|
|
auto old_quant_params = primitive_c->input_quant_params(); |
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> new_quant_params; |
|
|
|
|
|
// add inputs as perm order |
|
|
|
|
|
for (size_t idx : perm) { |
|
|
|
|
|
if (idx > cnode->inputs().size() - 1) { |
|
|
|
|
|
MS_LOG(ERROR) << "Idx " << idx << " is larger than inputs size: " << cnode->inputs().size() - 1; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
new_inputs.emplace_back(cnode->input(idx)); |
|
|
|
|
|
new_quant_params.emplace_back(old_quant_params.at(idx - 1)); |
|
|
|
|
|
} |
|
|
|
|
|
cnode->set_inputs(new_inputs); |
|
|
|
|
|
primitive_c->set_input_quant_params(new_quant_params); |
|
|
|
|
|
return RET_OK; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) { |
|
|
bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) { |
|
|
auto node_list = TopoSort(graph->get_return()); |
|
|
auto node_list = TopoSort(graph->get_return()); |
|
|
for (auto &node : node_list) { |
|
|
for (auto &node : node_list) { |
|
|
@@ -33,50 +59,46 @@ bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) { |
|
|
} |
|
|
} |
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); |
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); |
|
|
if (opt::GetCNodeType(node) == schema::PrimitiveType_DeConv2D) { |
|
|
|
|
|
cnode->set_input(1, cnode->input(3)); |
|
|
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
|
|
inputs.pop_back(); |
|
|
|
|
|
cnode->set_inputs(inputs); |
|
|
|
|
|
|
|
|
|
|
|
auto input_quant_params = primitive_c->input_quant_params(); |
|
|
|
|
|
input_quant_params[0] = input_quant_params.at(2); |
|
|
|
|
|
input_quant_params.pop_back(); |
|
|
|
|
|
primitive_c->set_input_quant_params(input_quant_params); |
|
|
|
|
|
|
|
|
if (opt::GetCNodeType(node) == schema::PrimitiveType_Fill) { |
|
|
|
|
|
// dims, value => value, dims |
|
|
|
|
|
if (RET_OK != ReorderCnodeInputs(cnode.get(), {2, 1})) { |
|
|
|
|
|
MS_LOG(ERROR) << "Reorder fill inputs failed"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) { |
|
|
|
|
|
cnode->set_input(1, cnode->input(2)); |
|
|
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
|
|
inputs.pop_back(); |
|
|
|
|
|
cnode->set_inputs(inputs); |
|
|
|
|
|
|
|
|
if (opt::GetCNodeType(node) == schema::PrimitiveType_DeConv2D) { |
|
|
|
|
|
// output_shape, weights, input => input, weight |
|
|
|
|
|
if (RET_OK != ReorderCnodeInputs(cnode.get(), {3, 2})) { |
|
|
|
|
|
MS_LOG(ERROR) << "Reorder deconv inputs failed"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
auto input_quant_params = primitive_c->input_quant_params(); |
|
|
|
|
|
input_quant_params[0] = input_quant_params.at(1); |
|
|
|
|
|
input_quant_params.pop_back(); |
|
|
|
|
|
primitive_c->set_input_quant_params(input_quant_params); |
|
|
|
|
|
|
|
|
if (opt::GetCNodeType(node) == schema::PrimitiveType_Split && cnode->inputs().size() == split_inputs_size) { |
|
|
|
|
|
// axis, input, ??? => input, axis |
|
|
|
|
|
if (RET_OK != ReorderCnodeInputs(cnode.get(), {2, 1})) { |
|
|
|
|
|
MS_LOG(ERROR) << "Reorder split inputs failed"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce || |
|
|
|
|
|
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin || |
|
|
|
|
|
opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax || |
|
|
|
|
|
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch || |
|
|
|
|
|
opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpace || |
|
|
|
|
|
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatchND || |
|
|
|
|
|
opt::GetCNodeType(node) == schema::PrimitiveType_BatchToSpaceND || |
|
|
|
|
|
opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToDepth || |
|
|
|
|
|
(opt::GetCNodeType(node) == schema::PrimitiveType_Pad && primitive_c->primitiveT()->value.AsPad() != nullptr && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsPad()->paddingMode == schema::PaddingMode_CONSTANT) || |
|
|
|
|
|
(opt::GetCNodeType(node) == schema::PrimitiveType_Resize && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsResize() != nullptr && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsResize()->newHeight != 0 && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsResize()->newWidth != 0)) { |
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs; |
|
|
|
|
|
new_inputs.emplace_back(cnode->input(0)); |
|
|
|
|
|
new_inputs.emplace_back(cnode->input(1)); |
|
|
|
|
|
cnode->set_inputs(new_inputs); |
|
|
|
|
|
|
|
|
bool is_single_input_pad = opt::GetCNodeType(node) == schema::PrimitiveType_Pad && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsPad() != nullptr && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsPad()->paddingMode == schema::PaddingMode_CONSTANT; |
|
|
|
|
|
bool is_single_input_resize = opt::GetCNodeType(node) == schema::PrimitiveType_Resize && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsResize() != nullptr && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsResize()->newHeight != 0 && |
|
|
|
|
|
primitive_c->primitiveT()->value.AsResize()->newWidth != 0; |
|
|
|
|
|
if (lite::IsContain(single_input_ops, opt::GetCNodeType(node)) || is_single_input_pad || is_single_input_resize) { |
|
|
|
|
|
if (RET_OK != ReorderCnodeInputs(cnode.get(), {1})) { |
|
|
|
|
|
MS_LOG(ERROR) << "Reorder single input failed"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|