Browse Source

syn code

tags/v0.3.0-alpha
chang zherui 5 years ago
parent
commit
87f7488e50
2 changed files with 11 additions and 13 deletions
  1. +9
    -5
      mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
  2. +2
    -8
      mindspore/ccsrc/transform/convert.cc

+ 9
- 5
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h View File

@@ -248,17 +248,18 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
if (addn->size() != 2) { if (addn->size() != 2) {
return nullptr; return nullptr;
} }

AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
return nullptr; return nullptr;
} }


auto addn_op_node = addn->input(0);
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
auto fg = node->func_graph(); auto fg = node->func_graph();
AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg);
AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg);
AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg);
return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg);
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
return NewCNode({mul_, all_reduce, y_}, fg);
} }


void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
@@ -269,6 +270,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
AnfVisitor::Match(prim::kPrimMul)(node); AnfVisitor::Match(prim::kPrimMul)(node);
level_ = 0; level_ = 0;
if (is_reduce_match_) { if (is_reduce_match_) {
mul_ = node->cast<CNodePtr>()->input(0);
y_ = tmp_; y_ = tmp_;
} else { } else {
z_ = node; z_ = node;
@@ -280,6 +282,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode->size() > 1) { if (cnode->size() > 1) {
all_reduce_ = cnode->input(0);
x_ = cnode->input(1); x_ = cnode->input(1);
is_reduce_match_ = true; is_reduce_match_ = true;
} }
@@ -302,6 +305,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
int level_{0}; int level_{0};
bool is_reduce_match_{false}; bool is_reduce_match_{false};
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr};
}; };


class ArithmeticSimplify { class ArithmeticSimplify {


+ 2
- 8
mindspore/ccsrc/transform/convert.cc View File

@@ -96,7 +96,6 @@ const char kNameConfusionMatrix[] = "ConfusionMatrix";
const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor"; const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor";
const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad"; const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad";
const char kNameApplyAdam[] = "Adam"; const char kNameApplyAdam[] = "Adam";
const char kNameExtractImagePatches[] = "ExtractImagePatches";
const char kNameReLU6[] = "ReLU6"; const char kNameReLU6[] = "ReLU6";
const char kNameReLU6Grad[] = "ReLU6Grad"; const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu"; const char kNameElu[] = "Elu";
@@ -111,8 +110,6 @@ const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits
const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad"; const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad";
const char kNameScatterNdD[] = "ScatterNd"; const char kNameScatterNdD[] = "ScatterNd";
const char kNamePadD[] = "Pad"; const char kNamePadD[] = "Pad";
const char kNameMirrorPad[] = "MirrorPad";
const char kNameMirrorPadGrad[] = "MirrorPadGrad";
const char kNameGatherNd[] = "GatherNd"; const char kNameGatherNd[] = "GatherNd";
const char kNameArgmax[] = "Argmax"; const char kNameArgmax[] = "Argmax";
const char kNameArgmin[] = "Argmin"; const char kNameArgmin[] = "Argmin";
@@ -216,7 +213,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
{string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
{string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
{string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
{prim::kPrimAssign->name(), ADPT_DESC(Assign)}, {prim::kPrimAssign->name(), ADPT_DESC(Assign)},
{prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)}, {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)},
{prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)}, {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)},
@@ -261,8 +257,6 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)}, {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)},
{string(kNameScatterNdD), ADPT_DESC(ScatterNdD)}, {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)},
{string(kNamePadD), ADPT_DESC(PadD)}, {string(kNamePadD), ADPT_DESC(PadD)},
{string(kNameMirrorPad), ADPT_DESC(MirrorPad)},
{string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)},
{string(kNameGatherNd), ADPT_DESC(GatherNd)}, {string(kNameGatherNd), ADPT_DESC(GatherNd)},
{string(kNameArgmax), ADPT_DESC(ArgMaxD)}, {string(kNameArgmax), ADPT_DESC(ArgMaxD)},
{string(kNameArgmin), ADPT_DESC(ArgMinD)}, {string(kNameArgmin), ADPT_DESC(ArgMinD)},
@@ -1128,8 +1122,8 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
if (desc == nullptr) { if (desc == nullptr) {
MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null."; MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
} else { } else {
(void)std::static_pointer_cast<Data>(op)->update_input_desc_data(*desc);
(void)std::static_pointer_cast<Data>(op)->update_output_desc_out(*desc);
(void)std::static_pointer_cast<Data>(op)->update_input_desc_x(*desc);
(void)std::static_pointer_cast<Data>(op)->update_output_desc_y(*desc);
} }
} }




Loading…
Cancel
Save