Browse Source

add attr for transdata node

tags/v1.0.0
WilliamLian 5 years ago
parent
commit
097f53bed9
8 changed files with 31 additions and 23 deletions
  1. +8
    -4
      mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
  2. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
  3. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
  4. +0
    -16
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc
  5. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h
  6. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc
  7. +17
    -0
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
  8. +1
    -0
      mindspore/ccsrc/utils/utils.h

+ 8
- 4
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc View File

@@ -108,10 +108,7 @@ std::string KernelBuildInfo::ToString() const {
return output_buffer.str();
}

bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
return false;
}
bool KernelBuildInfo::IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const {
if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) {
if (op_pattern_ != kFormatAgnosticPattern) {
return false;
@@ -123,6 +120,13 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
}

bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
return false;
}
return IsSimilarityKernelBuildInfo(other);
}

bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }

bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }


+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h View File

@@ -91,6 +91,8 @@ class KernelBuildInfo {

std::string ToString() const;

bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const;

bool operator==(const KernelBuildInfo &other) const;

bool operator!=(const KernelBuildInfo &other) const;


+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc View File

@@ -130,6 +130,7 @@ void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel:
AicpuMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}

bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
@@ -140,7 +141,7 @@ bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *select_kernel_build_info;
return item->IsSimilarityKernelBuildInfo(*select_kernel_build_info);
});
}



+ 0
- 16
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc View File

@@ -178,22 +178,6 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
}
}

void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) {
std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0);
std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0);
if (input_format == kOpFormat_DEFAULT) {
input_format = kOpFormat_NCHW;
}
if (output_format == kOpFormat_DEFAULT) {
output_format = kOpFormat_NCHW;
}
AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node);
AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node);
}
}

std::unordered_set<std::string> input_order_adjusted_ops = {
"Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop",
"LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"};


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h View File

@@ -36,7 +36,7 @@ class TbeAdapter {
TbeAdapter() = default;
~TbeAdapter() = default;
static void NormalizeFuncName(std::string *func_name);
static void SetTbeAttrsForTransDataOp(const AnfNodePtr &anf_node);
static void InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list,
nlohmann::json *inputs_json);
static bool RunAttrPass(const AnfNodePtr &anf_node, const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,


+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc View File

@@ -75,7 +75,6 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
set<std::string> processed_kernel;
for (const auto &anf_node : anf_nodes) {
// gen kernel json
tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node);
if (AnfAlgo::GetKernelMod(anf_node) != nullptr) {
continue;
}


+ 17
- 0
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc View File

@@ -48,6 +48,22 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
return reshape;
}

void SetTransNodeAttr(const CNodePtr &trans_node) {
MS_EXCEPTION_IF_NULL(trans_node);
if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) {
std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
if (input_format == kOpFormat_DEFAULT) {
input_format = kOpFormat_NCHW;
}
if (output_format == kOpFormat_DEFAULT) {
output_format = kOpFormat_NCHW;
}
AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
}
}

AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
AnfNodePtr trans_node = nullptr;
@@ -173,6 +189,7 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
builder->SetInputsDeviceType({type_id});
}
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
SetTransNodeAttr(trans_data->cast<CNodePtr>());
}

CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -224,6 +224,7 @@ constexpr auto kAttrEventId = "event_id";
constexpr auto kAttrDynInput = "dynamic";
constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
constexpr auto kAttrSrcFormat = "src_format";
constexpr auto kAttrDstFormat = "dst_format";
constexpr auto kAttrMultiples = "multiples";
constexpr auto kAttrFixPrecision = "fix_precision";
constexpr auto kAttrOutputPrecision = "output_precision";


Loading…
Cancel
Save