|
|
|
@@ -15,7 +15,9 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "backend/optimizer/ascend/mindir/dropout_unify_mindir.h" |
|
|
|
#include <ops/all_ops.h> |
|
|
|
#include <vector> |
|
|
|
#include <string> |
|
|
|
#include <memory> |
|
|
|
#include <numeric> |
|
|
|
#include <algorithm> |
|
|
|
@@ -23,45 +25,69 @@ |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
|
|
|
|
/*
    DropoutGenMask:
    attr: seed0 seed1:
    input:  1. shape <>;
            2. keep_prob: its type is based on the type of input x: if x is float or float16, that type is used;
               otherwise float16 is used;
    output: shape: (count + 127) / 128 * 16
 */
|
|
|
namespace mindspore::opt { |
|
|
|
namespace { |
|
|
|
constexpr auto kKeepProb = "keep_prob"; |
|
|
|
constexpr auto kSeed0 = "Seed0"; |
|
|
|
constexpr auto kSeed1 = "Seed1"; |
|
|
|
constexpr auto kUint8BitSize = 8; |
|
|
|
|
|
|
|
namespace mindspore::opt { |
|
|
|
constexpr int64_t kMaskAlignNum = 128; |
|
|
|
constexpr int64_t kMaskMultiNum = 16; |
|
|
|
constexpr size_t kFloat16Len = 2; // size of float16 |
|
|
|
namespace { |
|
|
|
AnfNodePtr GetDropoutKeepProb(const AnfNodePtr &node, float *keep_prob) { |
|
|
|
MS_LOG(INFO) << "GetDropoutNodeInfo start."; |
|
|
|
constexpr size_t kInt64Len = 8; // size of int64 |
|
|
|
|
|
|
|
TypeId GetInputXDataType(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(keep_prob); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode) || !AnfAlgo::HasNodeAttr(kSeed0, cnode) || |
|
|
|
!AnfAlgo::HasNodeAttr(kSeed1, cnode)) { |
|
|
|
MS_LOG(EXCEPTION) << "Dropout node does nothave attr: keep_prob or seed0 or seed1."; |
|
|
|
auto dropout_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); |
|
|
|
if (dropout_input_type != kNumberTypeFloat32 && dropout_input_type != kNumberTypeFloat && |
|
|
|
dropout_input_type != kNumberTypeFloat16) { |
|
|
|
dropout_input_type = kNumberTypeFloat16; |
|
|
|
} |
|
|
|
*keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb); |
|
|
|
MS_LOG(INFO) << "keep_prob: " << *keep_prob; |
|
|
|
// return dropout input. maybe tensor or pre cnode output |
|
|
|
return cnode->input(1); |
|
|
|
MS_LOG(INFO) << "Dropout input data type: " << TypeIdLabel(dropout_input_type); |
|
|
|
return dropout_input_type; |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float &keep_prob, const TypePtr &dtype) { |
|
|
|
MS_LOG(INFO) << "CreateKeepPorbValueNode start."; |
|
|
|
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
std::vector<int64_t> shapes; |
|
|
|
auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); |
|
|
|
std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong); |
|
|
|
return shapes; |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, TypeId type_id) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
// Step1: get keep_prob |
|
|
|
if (!AnfAlgo::HasNodeAttr(kKeepProb, cnode)) { |
|
|
|
MS_LOG(EXCEPTION) << "Dropout node does not have attr: keep_prob."; |
|
|
|
} |
|
|
|
if (AnfAlgo::GetCNodePrimitive(cnode)->ToString() == kDropoutOpName) { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kSeed0, cnode) || !AnfAlgo::HasNodeAttr(kSeed1, cnode)) { |
|
|
|
MS_LOG(EXCEPTION) << "Dropout node does not have attr: seed0 or seed1."; |
|
|
|
} |
|
|
|
} |
|
|
|
auto keep_prob = AnfAlgo::GetNodeAttr<float>(node, kKeepProb); |
|
|
|
MS_LOG(INFO) << "Keep_prob value: " << keep_prob; |
|
|
|
|
|
|
|
std::vector<int64_t> keep_prob_shape = {}; |
|
|
|
ShapeVector shape = {}; |
|
|
|
auto keep_prob_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), keep_prob_shape); |
|
|
|
auto keep_prob_tensor = std::make_shared<tensor::Tensor>(type_id, keep_prob_shape); |
|
|
|
MS_EXCEPTION_IF_NULL(keep_prob_tensor); |
|
|
|
auto data_ptr = keep_prob_tensor->data_c(); |
|
|
|
MS_EXCEPTION_IF_NULL(data_ptr); |
|
|
|
// keep_prob's datatype is same with input data |
|
|
|
if (dtype->type_id() == kNumberTypeFloat16) { |
|
|
|
float16 half_data = float16(keep_prob); |
|
|
|
auto ret_code = memcpy_s(data_ptr, kFloat16Len, &half_data, kFloat16Len); |
|
|
|
if (type_id == kNumberTypeFloat16) { |
|
|
|
auto half_data = float16(keep_prob); |
|
|
|
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(keep_prob_tensor->data().nbytes()), &half_data, kFloat16Len); |
|
|
|
if (ret_code != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; |
|
|
|
} |
|
|
|
@@ -69,59 +95,65 @@ ValueNodePtr CreateKeepPorbValueNode(const FuncGraphPtr &func_graph, const float |
|
|
|
auto *val = reinterpret_cast<float *>(data_ptr); |
|
|
|
*val = keep_prob; |
|
|
|
} |
|
|
|
auto abstract = std::make_shared<abstract::AbstractTensor>(dtype, shape); |
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), keep_prob_shape); |
|
|
|
auto keep_prob_value = kernel_graph->NewValueNode(abstract, keep_prob_tensor); |
|
|
|
MS_EXCEPTION_IF_NULL(keep_prob_value); |
|
|
|
kernel_graph->AddValueNodeToGraph(keep_prob_value); |
|
|
|
return keep_prob_value; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> GetInputShape(const AnfNodePtr &node, const AnfNodePtr &dropout_input) { |
|
|
|
MS_LOG(INFO) << "GetInputShape start."; |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_input); |
|
|
|
std::vector<int64_t> shapes; |
|
|
|
if (dropout_input->isa<Parameter>()) { |
|
|
|
MS_LOG(INFO) << "Dropout input from parameter node."; |
|
|
|
// single test case |
|
|
|
auto dropout_input_value = dropout_input->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_input_value); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_input_value->Shape()); |
|
|
|
auto shape = dropout_input_value->Shape()->cast<abstract::ShapePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(shape); |
|
|
|
return shape->shape(); |
|
|
|
} else if (dropout_input->isa<CNode>()) { |
|
|
|
MS_LOG(INFO) << "Dropout input from cnode."; |
|
|
|
auto dropout_input_node = dropout_input->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_input_node); |
|
|
|
auto shape_size_t = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); |
|
|
|
std::transform(shape_size_t.begin(), shape_size_t.end(), std::back_inserter(shapes), SizeToLong); |
|
|
|
return shapes; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Dropout input is not parameter or cnode."; |
|
|
|
return {}; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape) { |
|
|
|
ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vector<int64_t> &shape, |
|
|
|
bool is_pynative = false) { |
|
|
|
MS_LOG(INFO) << "CreateShapeValueNode start."; |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
std::vector<ValuePtr> dim_values{}; |
|
|
|
abstract::AbstractBasePtrList abs{}; |
|
|
|
for (const auto &dim : shape) { |
|
|
|
dim_values.push_back(MakeValue(dim)); |
|
|
|
abs.push_back(std::make_shared<abstract::AbstractScalar>(dim)); |
|
|
|
ValuePtr shape_value = nullptr; |
|
|
|
AbstractBasePtr abstract = nullptr; |
|
|
|
if (is_pynative) { |
|
|
|
// pynative mode need to create tensor |
|
|
|
int64_t shape_dim = SizeToLong(shape.size()); |
|
|
|
std::vector<int64_t> shape_vec_shape = {shape_dim}; |
|
|
|
auto shape_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, shape_vec_shape); |
|
|
|
MS_EXCEPTION_IF_NULL(shape_tensor); |
|
|
|
auto data_ptr = shape_tensor->data_c(); |
|
|
|
MS_EXCEPTION_IF_NULL(data_ptr); |
|
|
|
auto elem_num = shape.size() * kInt64Len; |
|
|
|
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(shape_tensor->data().nbytes()), &shape[0], elem_num); |
|
|
|
if (ret_code != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; |
|
|
|
} |
|
|
|
shape_value = shape_tensor; |
|
|
|
abstract = std::make_shared<abstract::AbstractTensor>(kInt64, shape_vec_shape); |
|
|
|
} else { |
|
|
|
std::vector<ValuePtr> dim_values{}; |
|
|
|
abstract::AbstractBasePtrList abs{}; |
|
|
|
for (const auto &dim : shape) { |
|
|
|
dim_values.push_back(MakeValue(dim)); |
|
|
|
abs.push_back(std::make_shared<abstract::AbstractScalar>(dim)); |
|
|
|
} |
|
|
|
shape_value = std::make_shared<ValueTuple>(dim_values); |
|
|
|
abstract = std::make_shared<abstract::AbstractTuple>(abs); |
|
|
|
} |
|
|
|
auto shape_value_tuple = std::make_shared<ValueTuple>(dim_values); |
|
|
|
MS_EXCEPTION_IF_NULL(shape_value_tuple); |
|
|
|
auto abstract = std::make_shared<abstract::AbstractTuple>(abs); |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
auto shape_value = kernel_graph->NewValueNode(abstract, shape_value_tuple); |
|
|
|
MS_EXCEPTION_IF_NULL(shape_value); |
|
|
|
kernel_graph->AddValueNodeToGraph(shape_value); |
|
|
|
return shape_value; |
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
auto shape_value_node = kernel_graph->NewValueNode(abstract, shape_value); |
|
|
|
MS_EXCEPTION_IF_NULL(shape_value_node); |
|
|
|
kernel_graph->AddValueNodeToGraph(shape_value_node); |
|
|
|
return shape_value_node; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t> CalDropoutGenMaskOutput(const std::vector<int64_t> &shape) { |
|
|
|
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()); |
|
|
|
auto output_count = output_size / kMaskAlignNum; |
|
|
|
if (output_size % kMaskAlignNum != 0) { |
|
|
|
output_count++; |
|
|
|
} |
|
|
|
auto ret = output_count * kMaskMultiNum; |
|
|
|
MS_LOG(INFO) << "Output_size: " << ret; |
|
|
|
return {ret}; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
@@ -141,34 +173,34 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_cnode); |
|
|
|
auto dropout_node = tuple_cnode->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_node); |
|
|
|
float keep_prob = 0; |
|
|
|
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); |
|
|
|
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32; |
|
|
|
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype); |
|
|
|
auto shape = GetInputShape(dropout_node, dropout_input); |
|
|
|
auto shape_value = CreateShapeValueNode(func_graph, shape); |
|
|
|
|
|
|
|
auto inputx_type_id = GetInputXDataType(dropout_node); |
|
|
|
auto inputx_shape = GetInputXShape(dropout_node); |
|
|
|
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape); |
|
|
|
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id); |
|
|
|
|
|
|
|
// CreateDropoutGenMask |
|
|
|
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()); |
|
|
|
output_size = output_size / kUint8BitSize; |
|
|
|
MS_LOG(INFO) << "Output_size: " << output_size; |
|
|
|
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)), |
|
|
|
shape_value, keep_prob_value}; |
|
|
|
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_gen_mask); |
|
|
|
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); |
|
|
|
ShapeVector dropout_gen_mask_output = {output_size}; |
|
|
|
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output); |
|
|
|
auto output_shape = CalDropoutGenMaskOutput(inputx_shape); |
|
|
|
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape); |
|
|
|
MS_EXCEPTION_IF_NULL(gen_mask_abstract); |
|
|
|
dropout_gen_mask->set_abstract(gen_mask_abstract); |
|
|
|
dropout_gen_mask->set_scope(node->scope()); |
|
|
|
|
|
|
|
// CreateDropoutDoMask |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_node); |
|
|
|
auto dropout_cnode = dropout_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_cnode); |
|
|
|
auto dropout_input = dropout_cnode->input(1); |
|
|
|
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), |
|
|
|
dropout_input, dropout_gen_mask, keep_prob_value}; |
|
|
|
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_do_mask); |
|
|
|
ShapeVector dropout_do_mask_output = shape; |
|
|
|
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask_output); |
|
|
|
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape); |
|
|
|
dropout_do_mask->set_abstract(do_mask_abstract); |
|
|
|
dropout_do_mask->set_scope(node->scope()); |
|
|
|
|
|
|
|
@@ -178,8 +210,6 @@ const AnfNodePtr DropoutUnifyMindIR::Process(const FuncGraphPtr &func_graph, con |
|
|
|
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { |
|
|
|
VarPtr X = std::make_shared<Var>(); |
|
|
|
VarPtr Y = std::make_shared<Var>(); |
|
|
|
MS_EXCEPTION_IF_NULL(X); |
|
|
|
MS_EXCEPTION_IF_NULL(Y); |
|
|
|
auto dropout_prim = std::make_shared<Primitive>(kDropoutOpName); |
|
|
|
auto tuple_getitem_prim = prim::kPrimTupleGetItem; |
|
|
|
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName); |
|
|
|
@@ -194,58 +224,74 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, |
|
|
|
const EquivPtr &equiv) const { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto dropout_grad = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_grad); |
|
|
|
auto tuple_getitem = dropout_grad->input(2); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
auto tuple_getitem_cnode = tuple_getitem->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); |
|
|
|
auto dropout_node = tuple_getitem_cnode->input(1); |
|
|
|
auto dropout_grad_cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_grad_cnode); |
|
|
|
auto getitem1_node = dropout_grad_cnode->input(2); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem1_node); |
|
|
|
auto getitem1_cnode = getitem1_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(getitem1_cnode); |
|
|
|
auto dropout_node = getitem1_cnode->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_node); |
|
|
|
float keep_prob = 0; |
|
|
|
auto dropout_input = GetDropoutKeepProb(dropout_node, &keep_prob); |
|
|
|
auto dropout_dtype = AnfAlgo::GetOutputInferDataType(dropout_node, 0) == kNumberTypeFloat16 ? kFloat16 : kFloat32; |
|
|
|
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, keep_prob, dropout_dtype); |
|
|
|
auto shape = GetInputShape(dropout_node, dropout_input); |
|
|
|
auto shape_value = CreateShapeValueNode(func_graph, shape); |
|
|
|
auto dropout_cnode = dropout_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_cnode); |
|
|
|
|
|
|
|
auto inputx_type_id = GetInputXDataType(dropout_node); |
|
|
|
auto inputx_shape = GetInputXShape(dropout_node); |
|
|
|
auto shape_value = CreateShapeValueNode(func_graph, inputx_shape); |
|
|
|
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, inputx_type_id); |
|
|
|
|
|
|
|
// CreateDropoutGenMask |
|
|
|
auto output_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()); |
|
|
|
output_size = output_size / kUint8BitSize; |
|
|
|
MS_LOG(INFO) << "Output_size: " << output_size; |
|
|
|
std::vector<AnfNodePtr> dropout_gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)), |
|
|
|
shape_value, keep_prob_value}; |
|
|
|
CNodePtr dropout_gen_mask = func_graph->NewCNode(dropout_gen_mask_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_gen_mask); |
|
|
|
AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); |
|
|
|
ShapeVector dropout_gen_mask_output = {output_size}; |
|
|
|
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, dropout_gen_mask_output); |
|
|
|
auto output_shape = CalDropoutGenMaskOutput(inputx_shape); |
|
|
|
auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, output_shape); |
|
|
|
MS_EXCEPTION_IF_NULL(gen_mask_abstract); |
|
|
|
dropout_gen_mask->set_abstract(gen_mask_abstract); |
|
|
|
dropout_gen_mask->set_scope(dropout_node->scope()); |
|
|
|
// AnfAlgo::CopyNodeAttrs(node, dropout_gen_mask); |
|
|
|
dropout_gen_mask->set_scope(node->scope()); |
|
|
|
|
|
|
|
// CreateDropoutDoMask-forward |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto &node_users = manager->node_users(); |
|
|
|
auto iter = node_users.find(dropout_node); |
|
|
|
CNodePtr dropout_do_mask1 = nullptr; |
|
|
|
if (iter != node_users.end()) { |
|
|
|
for (auto &node_index : iter->second) { |
|
|
|
// Dropout has two outputs, so output node is tuple_getitem |
|
|
|
auto tuple_getitem_cnode2 = node_index.first->cast<CNodePtr>(); |
|
|
|
// check if Dropout's first output, which is used by forward, is used. |
|
|
|
auto getitem_index = GetValue<int64_t>(tuple_getitem_cnode2->input(2)->cast<ValueNodePtr>()->value()); |
|
|
|
if (getitem_index == 0) { |
|
|
|
std::vector<AnfNodePtr> dropout_do_mask1_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), |
|
|
|
dropout_input, dropout_gen_mask, keep_prob_value}; |
|
|
|
auto dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_do_mask1); |
|
|
|
ShapeVector dropout_do_mask1_output = shape; |
|
|
|
auto do_mask_abstract1 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask1_output); |
|
|
|
dropout_do_mask1->set_abstract(do_mask_abstract1); |
|
|
|
dropout_do_mask1->set_scope(dropout_node->scope()); |
|
|
|
(void)manager->Replace(tuple_getitem_cnode2, dropout_do_mask1); |
|
|
|
break; |
|
|
|
auto used_node = node_index.first; |
|
|
|
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimTupleGetItem)) { |
|
|
|
// check if Dropout's first output, which is used by forward, is used |
|
|
|
if (AnfAlgo::GetTupleGetItemOutIndex(used_node->cast<CNodePtr>()) == 0) { |
|
|
|
// if Dropout's first output is used, create forward DropoutDoMask |
|
|
|
auto dropout_input = dropout_cnode->input(1); |
|
|
|
std::vector<AnfNodePtr> dropout_do_mask1_inputs{ |
|
|
|
NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), dropout_input, dropout_gen_mask, |
|
|
|
keep_prob_value}; |
|
|
|
dropout_do_mask1 = func_graph->NewCNode(dropout_do_mask1_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_do_mask1); |
|
|
|
auto do_mask_abstract1 = |
|
|
|
std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape); |
|
|
|
dropout_do_mask1->set_abstract(do_mask_abstract1); |
|
|
|
dropout_do_mask1->set_scope(dropout_node->scope()); |
|
|
|
(void)manager->Replace(used_node, dropout_do_mask1); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (dropout_do_mask1 != nullptr) { |
|
|
|
// Dropout is used by ControlDepend in some situation, need to replace ControlDepend. |
|
|
|
auto &users = manager->node_users(); |
|
|
|
iter = users.find(dropout_node); |
|
|
|
if (iter != users.end()) { |
|
|
|
for (auto &node_index : iter->second) { |
|
|
|
auto used_node = node_index.first; |
|
|
|
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) { |
|
|
|
(void)manager->Replace(used_node, dropout_do_mask1); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -254,16 +300,112 @@ const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, |
|
|
|
if (equiv->find(grad_input_) == equiv->end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find grad_input in this pattern."; |
|
|
|
} |
|
|
|
auto grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]); |
|
|
|
std::vector<AnfNodePtr> dropout_do_mask2_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), |
|
|
|
grad_input, dropout_gen_mask, keep_prob_value}; |
|
|
|
auto dropout_do_mask2 = func_graph->NewCNode(dropout_do_mask2_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_do_mask2); |
|
|
|
ShapeVector dropout_do_mask2_output = shape; |
|
|
|
auto do_mask_abstract2 = std::make_shared<abstract::AbstractTensor>(dropout_dtype, dropout_do_mask2_output); |
|
|
|
dropout_do_mask2->set_abstract(do_mask_abstract2); |
|
|
|
dropout_do_mask2->set_scope(node->scope()); |
|
|
|
|
|
|
|
return dropout_do_mask2; |
|
|
|
auto dropout_grad_input = utils::cast<AnfNodePtr>((*equiv)[grad_input_]); |
|
|
|
std::vector<AnfNodePtr> dropout_do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), |
|
|
|
dropout_grad_input, dropout_gen_mask, keep_prob_value}; |
|
|
|
auto dropout_do_mask = func_graph->NewCNode(dropout_do_mask_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(dropout_do_mask); |
|
|
|
auto do_mask_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(inputx_type_id), inputx_shape); |
|
|
|
dropout_do_mask->set_abstract(do_mask_abstract); |
|
|
|
dropout_do_mask->set_scope(node->scope()); |
|
|
|
|
|
|
|
return dropout_do_mask; |
|
|
|
} |
|
|
|
|
|
|
|
// Pattern: MakeTuple(TupleGetItem(Dropout(X), Y), TupleGetItem(Dropout(X), Z)),
// i.e. both outputs of one Dropout packed into a tuple.
const BaseRef DropoutUnifyMindIRPynative::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  VarPtr Y = std::make_shared<Var>();
  VarPtr Z = std::make_shared<Var>();
  const auto dropout_ref = VectorRef({prim::kPrimDropout, X});
  const auto first_output = VectorRef({prim::kPrimTupleGetItem, dropout_ref, Y});
  const auto second_output = VectorRef({prim::kPrimTupleGetItem, dropout_ref, Z});
  return VectorRef({prim::kPrimMakeTuple, first_output, second_output});
}
|
|
|
|
|
|
|
// Rewrites the matched MakeTuple(TupleGetItem(Dropout, _), TupleGetItem(Dropout, _)):
//  - a DropoutGenMask node is built from the Dropout input shape (as a tensor value node) and keep_prob;
//  - a DropoutDoMask node is built from the Dropout input, the generated mask and keep_prob;
//  - through the graph manager, the first TupleGetItem is replaced by DropoutDoMask and the
//    second by DropoutGenMask.
// Returns the original matched `node`.
const AnfNodePtr DropoutUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                     const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  // Walk from the MakeTuple through its second TupleGetItem down to the Dropout node.
  auto make_tuple = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple);
  auto item0 = make_tuple->input(1);
  MS_EXCEPTION_IF_NULL(item0);
  auto item1 = make_tuple->input(2);
  MS_EXCEPTION_IF_NULL(item1);
  auto item1_cnode = item1->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(item1_cnode);
  auto dropout_node = item1_cnode->input(1);
  MS_EXCEPTION_IF_NULL(dropout_node);

  // Value-node inputs shared by the two new operators.
  auto x_type_id = GetInputXDataType(dropout_node);
  auto x_shape = GetInputXShape(dropout_node);
  auto shape_value = CreateShapeValueNode(func_graph, x_shape, true);
  auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_node, x_type_id);

  // DropoutGenMask(shape, keep_prob) -> uint8 mask.
  std::vector<AnfNodePtr> gen_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutGenMaskOpName)),
                                          shape_value, keep_prob_value};
  CNodePtr gen_mask = func_graph->NewCNode(gen_mask_inputs);
  MS_EXCEPTION_IF_NULL(gen_mask);
  // NOTE(review): attributes are copied from the matched MakeTuple node, whereas the keep_prob/Seed0/Seed1
  // attributes are checked on the Dropout node elsewhere in this file -- confirm this is the intended source.
  AnfAlgo::CopyNodeAttrs(node, gen_mask);
  auto mask_shape = CalDropoutGenMaskOutput(x_shape);
  auto gen_mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
  MS_EXCEPTION_IF_NULL(gen_mask_abstract);
  gen_mask->set_abstract(gen_mask_abstract);
  gen_mask->set_scope(node->scope());

  // DropoutDoMask(x, mask, keep_prob) -> tensor with the type and shape of x.
  auto dropout_cnode = dropout_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dropout_cnode);
  auto dropout_x = dropout_cnode->input(1);
  std::vector<AnfNodePtr> do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), dropout_x,
                                         gen_mask, keep_prob_value};
  auto do_mask = func_graph->NewCNode(do_mask_inputs);
  MS_EXCEPTION_IF_NULL(do_mask);
  do_mask->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(x_type_id), x_shape));
  do_mask->set_scope(node->scope());

  // replace genmask and domask
  auto graph_manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);
  (void)graph_manager->Replace(item0, do_mask);
  (void)graph_manager->Replace(item1, gen_mask);

  return node;
}
|
|
|
|
|
|
|
// Pattern: DropoutGrad(X, Y) with two arbitrary inputs.
const BaseRef DropoutGradUnifyMindIRPynative::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  VarPtr Y = std::make_shared<Var>();
  const auto grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
  VectorRef grad_pattern({grad_prim, X, Y});
  return grad_pattern;
}
|
|
|
|
|
|
|
// Replaces a matched DropoutGrad node with DropoutDoMask(grad_input, mask_input, keep_prob).
// The keep_prob value node is created from the DropoutGrad node itself, and the new node's abstract
// uses the data type and shape obtained for input 0 of DropoutGrad.
// Returns the newly created DropoutDoMask node.
const AnfNodePtr DropoutGradUnifyMindIRPynative::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                         const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto grad_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(grad_cnode);

  auto dy_type_id = GetInputXDataType(grad_cnode);
  auto dy_shape = GetInputXShape(grad_cnode);
  auto keep_prob_value = CreateKeepPorbValueNode(func_graph, grad_cnode, dy_type_id);

  // CreateDropoutDoMask
  auto dy = grad_cnode->input(1);
  auto mask = grad_cnode->input(2);
  std::vector<AnfNodePtr> do_mask_inputs{NewValueNode(std::make_shared<Primitive>(kDropoutDoMaskOpName)), dy, mask,
                                         keep_prob_value};
  auto do_mask = func_graph->NewCNode(do_mask_inputs);
  MS_EXCEPTION_IF_NULL(do_mask);
  do_mask->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(dy_type_id), dy_shape));
  do_mask->set_scope(node->scope());
  return do_mask;
}
|
|
|
} // namespace mindspore::opt |