|
|
|
@@ -455,6 +455,9 @@ Status GatherV2PInfo::InferForwardCommunication() { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Group failed."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (group_.name().empty()) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(group_.name())); |
|
|
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); |
|
|
|
OperatorAttrs attrs = {attr_op, attr_group}; |
|
|
|
@@ -472,7 +475,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
MS_LOG(ERROR) << "GenerateGraph Init failed"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (manual_split_) { |
|
|
|
if (manual_split_ && target_ != CPU) { |
|
|
|
if (InferOffset() != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Bias failed."; |
|
|
|
return FAILED; |
|
|
|
@@ -519,7 +522,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
|
|
|
|
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
if (manual_split_) { |
|
|
|
if (manual_split_ && target_ != CPU) { |
|
|
|
if (ComputeReplaceGraph(cnode) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; |
|
|
|
return nullptr; |
|
|
|
@@ -540,13 +543,24 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
|
|
|
|
Status GatherV2PInfo::ComputeReplaceOp() { |
|
|
|
if (InferBias() != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer offset failed."; |
|
|
|
return FAILED; |
|
|
|
int32_t bias = 0; |
|
|
|
if (manual_split_) { |
|
|
|
if (InferOffset() != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer offset failed."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
bias = index_offset_; |
|
|
|
} else { |
|
|
|
if (InferBias() != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Infer offset failed."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
bias = bias_; |
|
|
|
} |
|
|
|
|
|
|
|
OperatorName op_name = EMBEDDING_LOOKUP; |
|
|
|
OperatorAttrs attrs; |
|
|
|
Attr param_offset = std::make_pair("offset", MakeValue(bias_)); |
|
|
|
Attr param_offset = std::make_pair("offset", MakeValue(bias)); |
|
|
|
OperatorParams params = {std::make_pair(param_offset, 3)}; |
|
|
|
OperatorArgs args = std::make_pair(attrs, params); |
|
|
|
Operator op = std::make_pair(op_name, args); |
|
|
|
|