|
|
|
@@ -1445,6 +1445,20 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { |
|
|
|
tuple_out_handle_cache_[node.get()] = tuple_items; |
|
|
|
} |
|
|
|
|
|
|
|
void DfGraphConvertor::ConvertTopK(const CNodePtr node) { |
|
|
|
MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32."; |
|
|
|
auto value_ptr = node->input(2)->cast<ValueNodePtr>(); |
|
|
|
std::ostringstream ss; |
|
|
|
ss << "op" << value_ptr.get(); |
|
|
|
op_draw_name_[value_ptr.get()] = ss.str(); |
|
|
|
compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl; |
|
|
|
auto int64_value = value_ptr->value()->cast<Int64ImmPtr>()->value(); |
|
|
|
OpAdapterPtr adpt = FindAdapter(value_ptr, training_); |
|
|
|
auto op = adpt->generate(value_ptr); |
|
|
|
adpt->setAttr(op, "value", static_cast<int32_t>(int64_value)); |
|
|
|
op_cache_[value_ptr.get()] = op; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) { |
|
|
|
const int TUPLE_GET_ITEM_INDEX = 2; |
|
|
|
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs |
|
|
|
@@ -1625,6 +1639,12 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// Convert TopK second input from int64 to int32. |
|
|
|
if (name == prim::kPrimTopK->name()) { |
|
|
|
ConvertTopK(node); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers |
|
|
|
if (name == prim::kPrimMakeTuple->name()) { |
|
|
|
ConvertMakeTuple(node); |
|
|
|
|