Merge pull request !3190 from liuxiao93/BasicLSTMCelltags/v0.6.0-beta
| @@ -201,6 +201,10 @@ const char kNameBatchToSpace[] = "BatchToSpace"; | |||||
| const char kNameAtan2[] = "Atan2"; | const char kNameAtan2[] = "Atan2"; | ||||
| const char kNameApplyRMSProp[] = "ApplyRMSProp"; | const char kNameApplyRMSProp[] = "ApplyRMSProp"; | ||||
| const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; | const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; | ||||
| const char kNameBasicLSTMCell[] = "BasicLSTMCell"; | |||||
| const char kNameBasicLSTMCellInputGrad[] = "BasicLSTMCellInputGrad"; | |||||
| const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad"; | |||||
| const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad"; | |||||
| const char kNameL2Loss[] = "L2Loss"; | const char kNameL2Loss[] = "L2Loss"; | ||||
| const char kNameCTCLoss[] = "CTCLoss"; | const char kNameCTCLoss[] = "CTCLoss"; | ||||
| const char kNameRange[] = "Range"; | const char kNameRange[] = "Range"; | ||||
| @@ -410,6 +414,10 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||||
| {string(kNameAtan2), ADPT_DESC(Atan2)}, | {string(kNameAtan2), ADPT_DESC(Atan2)}, | ||||
| {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, | {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, | ||||
| {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)}, | {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSPropD)}, | ||||
| {string(kNameBasicLSTMCell), ADPT_DESC(BasicLSTMCell)}, | |||||
| {string(kNameBasicLSTMCellInputGrad), ADPT_DESC(BasicLSTMCellInputGrad)}, | |||||
| {string(kNameBasicLSTMCellWeightGrad), ADPT_DESC(BasicLSTMCellWeightGrad)}, | |||||
| {string(kNameBasicLSTMCellCStateGrad), ADPT_DESC(BasicLSTMCellCStateGrad)}, | |||||
| {string(kNameL2Loss), ADPT_DESC(L2Loss)}, | {string(kNameL2Loss), ADPT_DESC(L2Loss)}, | ||||
| {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, | {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, | ||||
| {string(kNameRange), ADPT_DESC(RangeD)}, | {string(kNameRange), ADPT_DESC(RangeD)}, | ||||
| @@ -1292,6 +1292,34 @@ ATTR_MAP(ApplyCenteredRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTra | |||||
| OUTPUT_MAP(ApplyCenteredRMSPropD) = { | OUTPUT_MAP(ApplyCenteredRMSPropD) = { | ||||
| {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}}; | {0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(mg)}, {2, OUTPUT_DESC(ms)}, {3, OUTPUT_DESC(mom)}}; | ||||
| // BasicLSTMCell | |||||
| INPUT_MAP(BasicLSTMCell) = { | |||||
| {1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(c)}, {4, INPUT_DESC(w)}, {5, INPUT_DESC(b)}}; | |||||
| ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}, | |||||
| {"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())}, | |||||
| {"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())}, | |||||
| {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}}; | |||||
| OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)}, | |||||
| {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}}; | |||||
| // BasicLSTMCellInputGrad | |||||
| INPUT_MAP(BasicLSTMCellInputGrad) = {{1, INPUT_DESC(dgate)}, {2, INPUT_DESC(w)}}; | |||||
| ATTR_MAP(BasicLSTMCellInputGrad) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())}}; | |||||
| OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht)}}; | |||||
| // BasicLSTMCellWeightGrad | |||||
| INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}}; | |||||
| ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP; | |||||
| OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}}; | |||||
| // BasicLSTMCellCStateGrad | |||||
| INPUT_MAP(BasicLSTMCellCStateGrad) = {{1, INPUT_DESC(c)}, {2, INPUT_DESC(dht)}, {3, INPUT_DESC(dct)}, | |||||
| {4, INPUT_DESC(it)}, {5, INPUT_DESC(jt)}, {6, INPUT_DESC(ft)}, | |||||
| {7, INPUT_DESC(ot)}, {8, INPUT_DESC(tanhct)}}; | |||||
| ATTR_MAP(BasicLSTMCellCStateGrad) = {{"forget_bias", ATTR_DESC(forget_bias, AnyTraits<float>())}, | |||||
| {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}}; | |||||
| OUTPUT_MAP(BasicLSTMCellCStateGrad) = {{0, OUTPUT_DESC(dgate)}, {1, OUTPUT_DESC(dct_1)}}; | |||||
| // L2Loss | // L2Loss | ||||
| INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; | INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; | ||||
| ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; | ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; | ||||
| @@ -488,6 +488,14 @@ DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) | |||||
| DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) | DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) | ||||
| DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD) | DECLARE_OP_ADAPTER(ApplyCenteredRMSPropD) | ||||
| DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD) | DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSPropD) | ||||
| DECLARE_OP_ADAPTER(BasicLSTMCell) | |||||
| DECLARE_OP_USE_OUTPUT(BasicLSTMCell) | |||||
| DECLARE_OP_ADAPTER(BasicLSTMCellInputGrad) | |||||
| DECLARE_OP_USE_OUTPUT(BasicLSTMCellInputGrad) | |||||
| DECLARE_OP_ADAPTER(BasicLSTMCellWeightGrad) | |||||
| DECLARE_OP_USE_OUTPUT(BasicLSTMCellWeightGrad) | |||||
| DECLARE_OP_ADAPTER(BasicLSTMCellCStateGrad) | |||||
| DECLARE_OP_USE_OUTPUT(BasicLSTMCellCStateGrad) | |||||
| DECLARE_OP_ADAPTER(L2Loss) | DECLARE_OP_ADAPTER(L2Loss) | ||||
| DECLARE_OP_USE_OUTPUT(L2Loss) | DECLARE_OP_USE_OUTPUT(L2Loss) | ||||
| DECLARE_OP_ADAPTER(CTCLoss) | DECLARE_OP_ADAPTER(CTCLoss) | ||||