Browse Source

!14334 opt transdata

From: @lianliguang
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
pull/14334/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
e408c9efc8
1 changed file with 10 additions and 8 deletions
  1. +10
    -8
      mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc

+ 10
- 8
mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc View File

@@ -354,13 +354,15 @@ void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
} }


void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string> &output_format, void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string> &output_format,
const CNodePtr &kernel_node, size_t input_index) {
const CNodePtr &kernel_node, size_t input_index, bool force_fresh = false) {
if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
return;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// we set special device info of a input tensor. // we set special device info of a input tensor.
bool is_ref = false;
auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node); auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
if (op_info != nullptr) { if (op_info != nullptr) {
is_ref = op_info->is_ref();
force_fresh = op_info->is_ref() || force_fresh;
} }
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node); auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
if (IsValueNode<tensor::Tensor>(real_input_node) && if (IsValueNode<tensor::Tensor>(real_input_node) &&
@@ -371,7 +373,7 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
return; return;
} }
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || force_fresh) {
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
builder->SetOutputsDeviceType(output_type); builder->SetOutputsDeviceType(output_type);
@@ -381,6 +383,9 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string


bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string &format) { bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string &format) {
MS_EXCEPTION_IF_NULL(input_node); MS_EXCEPTION_IF_NULL(input_node);
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return false;
}
if (!input_node->isa<CNode>()) { if (!input_node->isa<CNode>()) {
return false; return false;
} }
@@ -397,7 +402,7 @@ bool RefreshCastAndParamWeightFormat(const AnfNodePtr &input_node, const string
info_builder->SetOutputsFormat({format}); info_builder->SetOutputsFormat({format});
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get()); AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
auto cast_input_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cast_node, 0), 0); auto cast_input_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cast_node, 0), 0);
SetWeightFormat(cast_input_node.first, {format}, cast_node, 0);
SetWeightFormat(cast_input_node.first, {format}, cast_node, 0, true);
return true; return true;
} }
} // namespace } // namespace
@@ -418,9 +423,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) { if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue; continue;
} }
if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
continue;
}
auto refresh_format = selected_kernel_info->GetInputFormat(input_index); auto refresh_format = selected_kernel_info->GetInputFormat(input_index);
std::vector<std::string> output_format = {refresh_format}; std::vector<std::string> output_format = {refresh_format};
// if not find in host convert format map means the host has not registered the convert function of this format // if not find in host convert format map means the host has not registered the convert function of this format


Loading…
Cancel
Save