|
|
|
@@ -383,6 +383,63 @@ ExecutorPy::~ExecutorPy() { |
|
|
|
ConfigManager::GetInstance().ResetConfig(); |
|
|
|
} |
|
|
|
|
|
|
|
void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node, |
|
|
|
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table) { |
|
|
|
std::string weight_name; |
|
|
|
auto x = root_node->input(1); |
|
|
|
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) { |
|
|
|
weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name(); |
|
|
|
} else { |
|
|
|
weight_name = weight_node->cast<ParameterPtr>()->name(); |
|
|
|
} |
|
|
|
// find the fakequant from input |
|
|
|
int64_t count = 0; |
|
|
|
const int64_t max_depth = 5; |
|
|
|
CNodePtr cnode = nullptr; |
|
|
|
auto is_quant_cnode = [](const AnfNodePtr &node) { |
|
|
|
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || |
|
|
|
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); |
|
|
|
}; |
|
|
|
while (!is_quant_cnode(x)) { |
|
|
|
if (count >= max_depth) { |
|
|
|
break; |
|
|
|
} |
|
|
|
cnode = x->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr || cnode->size() <= 1) { |
|
|
|
break; |
|
|
|
} |
|
|
|
x = cnode->input(1); |
|
|
|
count += 1; |
|
|
|
} |
|
|
|
if (x->isa<Parameter>() || IsPrimitiveCNode(x, prim::kPrimLoad)) { |
|
|
|
(*fake_quant_table)[weight_name] = std::make_pair(nullptr, "input"); |
|
|
|
} |
|
|
|
// get the fakequant parameter minq's name |
|
|
|
if (!is_quant_cnode(x)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
cnode = x->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != 4) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto fakequant_min_node = cnode->input(2); |
|
|
|
if (!fakequant_min_node->isa<Parameter>() && !IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
std::string fakequant_min_node_name; |
|
|
|
if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) { |
|
|
|
fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name(); |
|
|
|
} else { |
|
|
|
fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name(); |
|
|
|
} |
|
|
|
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value(); |
|
|
|
if (!quant_op_value->isa<PrimitivePy>()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto quant_op = quant_op_value->cast<PrimitivePyPtr>(); |
|
|
|
(*fake_quant_table)[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); |
|
|
|
} |
|
|
|
|
|
|
|
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport( |
|
|
|
const std::string &phase_s) { |
|
|
|
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); |
|
|
|
@@ -399,58 +456,21 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI |
|
|
|
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); |
|
|
|
}; |
|
|
|
for (const auto &node : nodes) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr || cnode->size() != 3) { |
|
|
|
auto root_node = node->cast<CNodePtr>(); |
|
|
|
if (root_node == nullptr || root_node->size() != 3) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto x = cnode->input(1); |
|
|
|
auto weight = cnode->input(2); |
|
|
|
auto weight = root_node->input(2); |
|
|
|
if (!is_quant_cnode(weight)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// get parameter weight's name |
|
|
|
cnode = weight->cast<CNodePtr>(); |
|
|
|
auto cnode = weight->cast<CNodePtr>(); |
|
|
|
auto weight_node = cnode->input(2); |
|
|
|
if (!weight_node->isa<Parameter>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto weight_name = weight_node->cast<ParameterPtr>()->name(); |
|
|
|
// find the fakequant from input |
|
|
|
int64_t count = 0; |
|
|
|
const int64_t max_depth = 5; |
|
|
|
while (!is_quant_cnode(x)) { |
|
|
|
if (count >= max_depth) { |
|
|
|
break; |
|
|
|
} |
|
|
|
cnode = x->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr || cnode->size() <= 1) { |
|
|
|
break; |
|
|
|
} |
|
|
|
x = cnode->input(1); |
|
|
|
count += 1; |
|
|
|
} |
|
|
|
if (x->isa<Parameter>()) { |
|
|
|
fake_quant_table[weight_name] = std::make_pair(nullptr, "input"); |
|
|
|
} |
|
|
|
// get the fakequant parameter minq's name |
|
|
|
if (!is_quant_cnode(x)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
cnode = x->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr || cnode->size() != 4) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto fakequant_min_node = cnode->input(2); |
|
|
|
if (!fakequant_min_node->isa<Parameter>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name(); |
|
|
|
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value(); |
|
|
|
if (!quant_op_value->isa<PrimitivePy>()) { |
|
|
|
if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto quant_op = quant_op_value->cast<PrimitivePyPtr>(); |
|
|
|
fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); |
|
|
|
GetWeightInfo(root_node, weight_node, &fake_quant_table); |
|
|
|
} |
|
|
|
|
|
|
|
return fake_quant_table; |
|
|
|
|