
optimizer concat output pass

tags/v1.2.0-rc1
jjfeing committed 4 years ago
commit 28d92b1e89
4 changed files with 66 additions and 21 deletions:

  1. mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc (+20, -12)
  2. mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h (+2, -1)
  3. mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc (+42, -6)
  4. mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h (+2, -2)

mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc (+20, -12)

@@ -82,13 +82,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
 }
 
 void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
-  // get dynamic inputs
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
-  MS_EXCEPTION_IF_NULL(primitive);
-  std::vector<int64_t> dyn_input_sizes;
-  if (primitive->HasAttr(kAttrDynInputSizes)) {
-    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
-  }
+  auto dyn_input_sizes = GetNodeDynamicInputs();
   // get real input/output num
   size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
   const auto inputs_info = op_info.inputs_ptr();
@@ -189,8 +183,9 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
     return;
   }
   std::vector<std::shared_ptr<KernelBuildInfo>> new_kernel_info_list;
+  auto dynamic_inputs = GetNodeDynamicInputs();
   for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
-    if (!FilterInVaildShape(iter)) {
+    if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) {
       continue;
     }
     if (op_info.need_check_supported()) {
@@ -203,13 +198,15 @@ void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
   (*kernel_info_list_) = new_kernel_info_list;
 }
 
-bool TbeKernelSelect::FilterInVaildShape(
-  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
+bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) {
   MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
   const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
-  for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
+  // A dynamic input only needs its first input checked, because the remaining inputs are copies of the first.
+  auto iter_num =
+    is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size();
+  for (size_t i = 0; i < iter_num; ++i) {
     auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
-    const auto &format = kernel_build_info_inputs_format[i];
+    const auto &format = kernel_build_info_inputs_format.at(i);
     if (!IsShapeMatchFormat(shape, format)) {
       return false;
     }
@@ -279,6 +276,17 @@ void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_
   builder->SetKernelType(TBE_KERNEL);
 }
 
+std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
+  // get dynamic inputs
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
+  MS_EXCEPTION_IF_NULL(primitive);
+  std::vector<int64_t> dyn_input_sizes;
+  if (primitive->HasAttr(kAttrDynInputSizes)) {
+    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
+  }
+  return dyn_input_sizes;
+}
+
 bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                                      const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                                      const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,
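
The two edits above work together: GetNodeDynamicInputs() centralizes the kAttrDynInputSizes attribute lookup, and FilterInVaildShape now inspects only the first input of a dynamic-input node, since the remaining inputs share its format. The following is a minimal standalone sketch of that pattern, not the MindSpore API: Node, the "dyn_input_sizes" map key, and the format strings are simplified stand-ins.

// Standalone sketch of the selection-time fast path added above.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Node {
  // Mirrors kAttrDynInputSizes: present only when the op has dynamic inputs.
  std::map<std::string, std::vector<int64_t>> attrs;
  std::vector<std::string> input_formats;
};

// Counterpart of GetNodeDynamicInputs(): return the attribute if set, else empty.
std::vector<int64_t> GetNodeDynamicInputs(const Node &node) {
  auto it = node.attrs.find("dyn_input_sizes");
  return it == node.attrs.end() ? std::vector<int64_t>{} : it->second;
}

// Counterpart of FilterInVaildShape(): with dynamic inputs, the real inputs are
// copies of the first one, so checking index 0 alone is sufficient.
bool FilterInvalidShape(const Node &node, bool is_dynamic_input) {
  size_t iter_num = is_dynamic_input && !node.input_formats.empty() ? 1 : node.input_formats.size();
  for (size_t i = 0; i < iter_num; ++i) {
    if (node.input_formats.at(i) != "DefaultFormat") {
      return false;  // format mismatch: this kernel build info would be dropped
    }
  }
  return true;
}

int main() {
  Node dyn{{{"dyn_input_sizes", {8}}}, {"DefaultFormat", "FRACTAL_NZ"}};
  // Dynamic case: only the first input is inspected, so the mismatched
  // second format is deliberately ignored.
  std::cout << FilterInvalidShape(dyn, !GetNodeDynamicInputs(dyn).empty()) << "\n";    // prints 1
  Node plain{{}, {"DefaultFormat", "FRACTAL_NZ"}};
  std::cout << FilterInvalidShape(plain, !GetNodeDynamicInputs(plain).empty()) << "\n";  // prints 0
}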


mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h (+2, -1)

@@ -44,10 +44,11 @@ class TbeKernelSelect {
   void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
   void GetReducePatternKernelInfo(const OpInfo &op_info);
   void FilterInVaildKernelInfo(const OpInfo &op_info);
-  bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter);
+  bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input);
   static bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format);
   bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
   static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
+  std::vector<int64_t> GetNodeDynamicInputs();
   bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
                       const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
                       const std::vector<int64_t> &dyn_input_sizes, std::vector<std::string> *formats,


mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc (+42, -6)

@@ -18,11 +18,47 @@
 #include <string>
 #include "backend/session/anf_runtime_algorithm.h"
 
-namespace mindspore {
-namespace opt {
+namespace mindspore::opt {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat) {
+  MS_EXCEPTION_IF_NULL(concat);
+  std::vector<std::string> inputs_device_format;
+  std::vector<std::string> outputs_device_format;
+  std::vector<TypeId> inputs_device_type;
+  std::vector<TypeId> outputs_device_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(concat); ++input_index) {
+    inputs_device_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(concat, input_index));
+    inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(concat, input_index));
+  }
+  // Currently only the default format & float16 are supported.
+  auto cmp_format = inputs_device_format.begin();
+  auto format_iter = std::find_if(inputs_device_format.begin(), inputs_device_format.end(),
+                                  [&](const auto &format) { return format != (*cmp_format); });
+  if (format_iter != inputs_device_format.end()) {
+    MS_LOG(EXCEPTION) << "Input format is not same, value: " << *format_iter;
+  }
+  auto cmp_dtype = inputs_device_type.begin();
+  auto dtype_iter = std::find_if(inputs_device_type.begin(), inputs_device_type.end(),
+                                 [&](const auto &dtype) { return dtype != (*cmp_dtype); });
+  if (dtype_iter != inputs_device_type.end()) {
+    MS_LOG(EXCEPTION) << "Input dtype is not same, value: " << *dtype_iter;
+  }
+  outputs_device_format.emplace_back(*cmp_format);
+  outputs_device_type.emplace_back(*cmp_dtype);
+
+  builder.SetFusionType(kernel::FusionType::OPAQUE);
+  builder.SetProcessor(kernel::Processor::AICORE);
+  builder.SetKernelType(TBE_KERNEL);
+  builder.SetInputsFormat(inputs_device_format);
+  builder.SetOutputsFormat(outputs_device_format);
+  builder.SetInputsDeviceType(inputs_device_type);
+  builder.SetOutputsDeviceType(outputs_device_type);
+  return builder.Build();
+}
+
 AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                             const std::vector<AnfNodePtr> &new_tuple_getitems,
-                                                            int64_t rank_size) const {
+                                                            int64_t rank_size) {
   MS_EXCEPTION_IF_NULL(func_graph);
   std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
   size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
@@ -43,7 +79,8 @@ AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &
     AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
     std::vector<int64_t> dyn_input_size{rank_size};
     AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
-    kernel_select_->SelectKernel(concat);
+    auto kernel_build_info = GenerateKernelBuildInfo(concat);
+    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, concat.get());
     make_tuple_inputs.push_back(concat);
   }
 
@@ -78,5 +115,4 @@ const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_gra
   CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
   return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
 }
-}  // namespace opt
-}  // namespace mindspore
+}  // namespace mindspore::opt
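
The new GenerateKernelBuildInfo builds the Concat node's kernel build info directly instead of going through kernel_select_->SelectKernel, after verifying that all inputs agree on one format and one dtype. Below is a minimal standalone sketch of that std::find_if uniformity check; plain strings stand in for device formats, CheckAllSame is a hypothetical name, and a thrown exception stands in for MS_LOG(EXCEPTION).

#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>

// Compare every element against the first and fail loudly on the first
// mismatch. Assumes at least one input, as the AllGather concat always has.
const std::string &CheckAllSame(const std::vector<std::string> &formats) {
  auto cmp = formats.begin();
  auto iter = std::find_if(formats.begin(), formats.end(),
                           [&](const auto &format) { return format != *cmp; });
  if (iter != formats.end()) {
    throw std::runtime_error("Input format is not same, value: " + *iter);
  }
  return *cmp;  // the common format, reused as the single output format
}

int main() {
  std::vector<std::string> ok{"DefaultFormat", "DefaultFormat"};
  std::vector<std::string> bad{"DefaultFormat", "FRACTAL_NZ"};
  CheckAllSame(ok);    // passes; returns "DefaultFormat"
  try {
    CheckAllSame(bad); // throws on "FRACTAL_NZ"
  } catch (const std::runtime_error &e) {
    // expected: "Input format is not same, value: FRACTAL_NZ"
  }
}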

mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h (+2, -2)

@@ -33,8 +33,8 @@ class ConcatOutputsForAllGather : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
-  AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                   const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) const;
+  static AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                          const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size);
   KernelSelectPtr kernel_select_;
 };
 } // namespace opt

