|
|
|
@@ -155,59 +155,20 @@ std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape) |
|
|
|
MS_LOG(INFO) << "Output_size: " << ret; |
|
|
|
return {ret}; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
// Pattern: TupleGetItem(Dropout(X), Y) — matches a Dropout whose output is
// consumed through a TupleGetItem, so Process() can lower the whole pair.
const BaseRef DropoutUnifyMindIR::DefinePattern() const {
  auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName);
  VarPtr input_var = std::make_shared<Var>();
  VarPtr index_var = std::make_shared<Var>();
  return VectorRef({prim::kPrimTupleGetItem, VectorRef({dropout_prim, input_var}), index_var});
}
|
|
|
|
|
|
|
// Rewrites the matched TupleGetItem(Dropout(x), _) into the backend op pair
// DropoutGenMask + DropoutDoMask and returns the DropoutDoMask node as the
// replacement for the matched getter.
const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                             const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  // `node` is the TupleGetItem matched by DefinePattern(); its input(1) is the Dropout cnode.
  auto tuple_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_cnode);
  auto dropout_node = tuple_cnode->input(1);
  MS_EXCEPTION_IF_NULL(dropout_node);

  // Gather the Dropout input's dtype/shape and build the constant inputs
  // (shape value node and keep-prob value node) used by both new ops.
  auto inputx_type_id = GetInputXDataType(dropout_node);
  auto inputx_shape = GetInputXShape(dropout_node);
  auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
  auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);

  // CreateDropoutGenMask: DropoutGenMask(shape, keep_prob) produces a uint8 mask
  // whose shape comes from CalDropoutGenMaskOutput().
  std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
                                                  shape_value, keep_prob_value};
  CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
  // Carry the matched node's attributes over to the generated GenMask op.
  AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
  auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
  auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
  MS_EXCEPTION_IF_NULL(gen_mask_abstract);
  dropout_gen_mask->set_abstract(gen_mask_abstract);
  dropout_gen_mask->set_scope(node->scope());

  // CreateDropoutDoMask: DropoutDoMask(x, mask, keep_prob); its abstract keeps
  // the original input's dtype and shape.
  MS_EXCEPTION_IF_NULL(dropout_node);
  auto dropout_cnode = dropout_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dropout_cnode);
  auto dropout_input = dropout_cnode->input(1);
  std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
                                                 dropout_input, dropout_gen_mask, keep_prob_value};
  auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_do_mask);
  auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
  dropout_do_mask->set_abstract(do_mask_abstract);
  dropout_do_mask->set_scope(node->scope());

  return dropout_do_mask;
  // NOTE(review): the closing '}' of Process() does not appear in this chunk —
  // this looks like patch residue; verify the brace is present in the full file.
|
|
|
bool NeedUpdate(const CNodePtr &getitem_cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(getitem_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem_cnode->input(2)); |
|
|
|
auto index_vnode = getitem_cnode->input(2)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(index_vnode); |
|
|
|
auto index_value = index_vnode->value(); |
|
|
|
MS_EXCEPTION_IF_NULL(index_value); |
|
|
|
auto index = GetValue<int64_t>(index_value); |
|
|
|
return index == 1; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { |
|
|
|
const BaseRef DropoutAndDropoutGradUnifyMindIR::DefinePattern() const { |
|
|
|
VarPtr X = std::make_shared<Var>(); |
|
|
|
VarPtr Y = std::make_shared<Var>(); |
|
|
|
auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName); |
|
|
|
@@ -220,8 +181,8 @@ const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { |
|
|
|
return VectorRef({dropout_grad_prim, grad_input_, ref1}); |
|
|
|
} |
|
|
|
|
|
|
|
const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &equiv) const { |
|
|
|
const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &equiv) const { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto dropout_grad_cnode = node->cast<CNodePtr>(); |
|
|
|
@@ -312,13 +273,74 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, |
|
|
|
return dropout_do_mask; |
|
|
|
} |
|
|
|
|
|
|
|
const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const { |
|
|
|
// Pattern: TupleGetItem(Dropout(X), Y) — entry point for rewriting a Dropout
// whose outputs are read through TupleGetItem.
const BaseRef DropoutUnifyMindIR0::DefinePattern() const {
  VarPtr var_x = std::make_shared<Var>();
  VarPtr var_y = std::make_shared<Var>();
  VectorRef dropout_ref({std::make_shared<Primitive>(kDropoutOpName), var_x});
  return VectorRef({prim::kPrimTupleGetItem, dropout_ref, var_y});
}
|
|
|
|
|
|
|
// Rewrites TupleGetItem(Dropout(x), 1) — the mask-output getter — into
// DropoutGenMask + DropoutDoMask, replaces the original Dropout with a
// MakeTuple of the two new ops, and returns the (re-typed) getter node.
const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto tuple_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_cnode);
  // Only handle getters of output index 1 (the mask output); leave others alone.
  if (!NeedUpdate(tuple_cnode)) {
    return nullptr;
  }

  // input(1) of the getter is the Dropout cnode matched by DefinePattern().
  auto dropout_node = tuple_cnode->input(1);
  MS_EXCEPTION_IF_NULL(dropout_node);
  auto inputx_type_id = GetInputXDataType(dropout_node);
  auto inputx_shape = GetInputXShape(dropout_node);
  auto shape_value = CreateShapeValueNode(func_graph, inputx_shape);
  auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id);

  // CreateDropoutGenMask: DropoutGenMask(shape, keep_prob) -> uint8 mask whose
  // shape is computed by CalDropoutGenMaskOutput().
  std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
                                                  shape_value, keep_prob_value};
  CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
  AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask);
  auto output_shape = CalDropoutGenMaskOutput(inputx_shape);
  auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape);
  MS_EXCEPTION_IF_NULL(gen_mask_abstract);
  dropout_gen_mask->set_abstract(gen_mask_abstract);
  dropout_gen_mask->set_scope(node->scope());

  // CreateDropoutDoMask: DropoutDoMask(x, mask, keep_prob); abstract keeps the
  // original input's dtype and shape.
  MS_EXCEPTION_IF_NULL(dropout_node);
  auto dropout_cnode = dropout_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dropout_cnode);
  auto dropout_input = dropout_cnode->input(1);
  std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)),
                                                 dropout_input, dropout_gen_mask, keep_prob_value};
  auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs);
  MS_EXCEPTION_IF_NULL(dropout_do_mask);
  auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape);
  dropout_do_mask->set_abstract(do_mask_abstract);
  dropout_do_mask->set_scope(node->scope());

  // make tuple to replace dropout: output 0 = DoMask result, output 1 = GenMask,
  // so existing TupleGetItem indices keep their meaning.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), dropout_do_mask, dropout_gen_mask};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  (void)manager->Replace(dropout_node, make_tuple);

  // The getter now selects the GenMask output (index 1 per NeedUpdate), so its
  // abstract becomes the uint8 mask abstract.
  tuple_cnode->set_abstract(gen_mask_abstract);
  return tuple_cnode;
}
|
|
|
|
|
|
|
// Pattern: a bare Dropout(X) node, with no TupleGetItem wrapper.
const BaseRef DropoutUnifyMindIR1::DefinePattern() const {
  VarPtr input_var = std::make_shared<Var>();
  return VectorRef({prim::kPrimDropout, input_var});
}
|
|
|
|
|
|
|
const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &) const { |
|
|
|
const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &) const { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto dropout_node = node->cast<CNodePtr>(); |
|
|
|
@@ -359,15 +381,15 @@ const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_gr |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
|
|
|
|
const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const { |
|
|
|
// Pattern: DropoutGrad(X, Y) — matched so Process() can lower the gradient
// op to a DropoutDoMask on the incoming grad and mask.
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
  auto grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
  VarPtr grad_var = std::make_shared<Var>();
  VarPtr mask_var = std::make_shared<Var>();
  return VectorRef({grad_prim, grad_var, mask_var});
}
|
|
|
|
|
|
|
const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &) const { |
|
|
|
const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &) const { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto dropout_grad_cnode = node->cast<CNodePtr>(); |
|
|
|
@@ -377,9 +399,25 @@ const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &fun |
|
|
|
auto grad_input_shape = GetInputXShape(dropout_grad_cnode); |
|
|
|
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id); |
|
|
|
|
|
|
|
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter |
|
|
|
// in that scene, need to be updated. |
|
|
|
auto mask_input = dropout_grad_cnode->input(2); |
|
|
|
if (mask_input->isa<Parameter>()) { |
|
|
|
// update abstract |
|
|
|
auto mask_abstract = mask_input->abstract(); |
|
|
|
MS_EXCEPTION_IF_NULL(mask_abstract); |
|
|
|
auto mask_shape = CalDropoutGenMaskOutput(grad_input_shape); |
|
|
|
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape); |
|
|
|
mask_input->set_abstract(mask_abstract); |
|
|
|
// update kernel info |
|
|
|
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); |
|
|
|
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT}); |
|
|
|
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8}); |
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get()); |
|
|
|
} |
|
|
|
|
|
|
|
// CreateDropoutDoMask |
|
|
|
auto grad_input = dropout_grad_cnode->input(1); |
|
|
|
auto mask_input = dropout_grad_cnode->input(2); |
|
|
|
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), |
|
|
|
grad_input, mask_input, keep_prob_value}; |
|
|
|
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); |
|
|
|
|