
add adapters for the Asin, Asinh, Atan, etc. operators for graphengine.

tags/v1.2.0-rc1
wangshuide2020, 4 years ago
commit d574de4c42
7 changed files with 259 additions and 0 deletions
  1. +25 -0   mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
  2. +109 -0  mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.cc
  3. +54 -0   mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h
  4. +44 -0   mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.cc
  5. +18 -0   mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.h
  6. +6 -0    mindspore/ccsrc/transform/graph_ir/op_declare/pad_ops_declare.cc
  7. +3 -0    mindspore/ccsrc/transform/graph_ir/op_declare/pad_ops_declare.h
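
All seven files follow the same op-adapter pattern: a DECLARE_OP_ADAPTER/DECLARE_OP_USE_OUTPUT pair in the *_declare.h header, plus INPUT_MAP/ATTR_MAP/OUTPUT_MAP definitions and a REG_ADPT_DESC registration in the matching *_declare.cc, keyed by a kName* constant from op_adapter_map.h. A condensed sketch of that pattern for one of the newly added operators (Asin; the lines are taken from the diffs below, with the surrounding mindspore::transform namespace and op_declare headers assumed, and the comments added for explanation):

// elewise_calculation_ops_declare.h: declare the adapter for the GE Asin op.
DECLARE_OP_ADAPTER(Asin)
DECLARE_OP_USE_OUTPUT(Asin)

// elewise_calculation_ops_declare.cc: wire MindSpore inputs/attrs/outputs to GE and register.
INPUT_MAP(Asin) = {{1, INPUT_DESC(x)}};    // MindSpore input 1 -> GE input "x"
ATTR_MAP(Asin) = EMPTY_ATTR_MAP;           // Asin has no attributes to convert
OUTPUT_MAP(Asin) = {{0, OUTPUT_DESC(y)}};  // GE output "y" -> MindSpore output 0
REG_ADPT_DESC(Asin, kNameAsin, ADPT_DESC(Asin))  // kNameAsin = "Asin", from op_adapter_map.h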

mindspore/ccsrc/transform/graph_ir/op_adapter_map.h (+25 -0)

@@ -31,6 +31,7 @@ constexpr const char kNameSimpleMean[] = "SimpleMean";
constexpr const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
constexpr const char kNameAllReduce[] = "AllReduce";
constexpr const char kNameBroadcast[] = "Broadcast";
constexpr const char kNameBroadcastTo[] = "BroadcastTo";
constexpr const char kNameAllgather[] = "AllGather";
constexpr const char kNameReduceScatter[] = "ReduceScatter";
constexpr const char kNameReduceSum[] = "ReduceSum";
@@ -52,6 +53,7 @@ constexpr const char kNameLogicalOr[] = "LogicalOr";
constexpr const char kNameExp[] = "Exp";
constexpr const char kNameLessEqual[] = "LessEqual";
constexpr const char kNameGreaterEqual[] = "GreaterEqual";
constexpr const char kNameApproximateEqual[] = "ApproximateEqual";
constexpr const char kNameEqual[] = "Equal";
constexpr const char kNameNotEqual[] = "NotEqual";
constexpr const char kNameFlattenGrad[] = "FlattenGrad";
@@ -75,6 +77,12 @@ constexpr const char kNameConfusionMatrix[] = "ConfusionMatrix";
constexpr const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor";
constexpr const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad";
constexpr const char kNameApplyAdam[] = "Adam";
constexpr const char kNameApplyAdagrad[] = "ApplyAdagrad";
constexpr const char kNameApplyAdadelta[] = "ApplyAdadelta";
constexpr const char kNameApplyAdaMax[] = "ApplyAdaMax";
constexpr const char kNameApplyGradientDescent[] = "ApplyGradientDescent";
constexpr const char kNameApplyPowerSign[] = "ApplyPowerSign";
constexpr const char kNameApplyProximalGradientDescent[] = "ApplyProximalGradientDescent";
constexpr const char kNameExtractImagePatches[] = "ExtractImagePatches";
constexpr const char kNameReLU6[] = "ReLU6";
constexpr const char kNameReLU6Grad[] = "ReLU6Grad";
@@ -116,13 +124,26 @@ constexpr const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
constexpr const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
constexpr const char kNameReshape[] = "Reshape";
constexpr const char kNameTransShape[] = "TransShape";
constexpr const char kNameDiv[] = "Div";
constexpr const char kNameRealDiv[] = "RealDiv";
constexpr const char kNameBitwiseAnd[] = "BitwiseAnd";
constexpr const char kNameBitwiseOr[] = "BitwiseOr";
constexpr const char kNameBitwiseXor[] = "BitwiseXor";
constexpr const char kNameCeil[] = "Ceil";
constexpr const char kNameCosineEmbeddingLoss[] = "CosineEmbeddingLoss";
constexpr const char kNameXdivy[] = "Xdivy";
constexpr const char kNameTile[] = "Tile";
constexpr const char kNameCos[] = "Cos";
constexpr const char kNameCosh[] = "Cosh";
constexpr const char kNameACos[] = "ACos";
constexpr const char kNameACosGrad[] = "ACosGrad";
constexpr const char kNameFloorDiv[] = "FloorDiv";
constexpr const char kNameSin[] = "Sin";
constexpr const char kNameSinh[] = "Sinh";
constexpr const char kNameAsin[] = "Asin";
constexpr const char kNameAsinGrad[] = "AsinGrad";
constexpr const char kNameAsinh[] = "Asinh";
constexpr const char kNameAsinhGrad[] = "AsinhGrad";
constexpr const char kNamePrelu[] = "PReLU";
constexpr const char kNamePreluGrad[] = "PReLUGrad";
constexpr const char kNameSigmoid[] = "Sigmoid";
@@ -180,6 +201,10 @@ constexpr const char kNameDiag[] = "Diag";
constexpr const char kNameDiagPart[] = "DiagPart";
constexpr const char kNameSpaceToBatch[] = "SpaceToBatch";
constexpr const char kNameBatchToSpace[] = "BatchToSpace";
constexpr const char kNameTan[] = "Tan";
constexpr const char kNameAtan[] = "Atan";
constexpr const char kNameAtanGrad[] = "AtanGrad";
constexpr const char kNameAtanh[] = "Atanh";
constexpr const char kNameAtan2[] = "Atan2";
constexpr const char kNameApplyRMSProp[] = "ApplyRMSProp";
constexpr const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp";


mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.cc (+109 -0)

@@ -58,6 +58,12 @@ ATTR_MAP(Cos) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Cos) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Cos, kNameCos, ADPT_DESC(Cos))

// Cosh
INPUT_MAP(Cosh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Cosh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Cosh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Cosh, kNameCosh, ADPT_DESC(Cosh))

// Acos
INPUT_MAP(Acos) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Acos) = EMPTY_ATTR_MAP;
@@ -82,6 +88,12 @@ ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AcoshGrad, kNameAcoshGrad, ADPT_DESC(AcoshGrad))

// Div
INPUT_MAP(Div) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Div) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Div) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Div, kNameDiv, ADPT_DESC(Div))

// Floor
INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Floor) = EMPTY_ATTR_MAP;
@@ -106,6 +118,73 @@ ATTR_MAP(Sin) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Sin) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Sin, kNameSin, ADPT_DESC(Sin))

// Sinh
INPUT_MAP(Sinh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Sinh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Sinh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Sinh, kNameSinh, ADPT_DESC(Sinh))

// Asin
INPUT_MAP(Asin) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Asin) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Asin) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Asin, kNameAsin, ADPT_DESC(Asin))

// AsinGrad
INPUT_MAP(AsinGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(AsinGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AsinGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AsinGrad, kNameAsinGrad, ADPT_DESC(AsinGrad))

// Asinh
INPUT_MAP(Asinh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Asinh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Asinh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Asinh, kNameAsinh, ADPT_DESC(Asinh))

// AsinhGrad
INPUT_MAP(AsinhGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(AsinhGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AsinhGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AsinhGrad, kNameAsinhGrad, ADPT_DESC(AsinhGrad))

// BitwiseAnd
INPUT_MAP(BitwiseAnd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(BitwiseAnd) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BitwiseAnd) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BitwiseAnd, kNameBitwiseAnd, ADPT_DESC(BitwiseAnd))

// BitwiseOr
INPUT_MAP(BitwiseOr) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(BitwiseOr) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BitwiseOr) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BitwiseOr, kNameBitwiseOr, ADPT_DESC(BitwiseOr))

// BitwiseXor
INPUT_MAP(BitwiseXor) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(BitwiseXor) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BitwiseXor) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BitwiseXor, kNameBitwiseXor, ADPT_DESC(BitwiseXor))

// Ceil
INPUT_MAP(Ceil) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Ceil) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Ceil) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Ceil, kNameCeil, ADPT_DESC(Ceil))

// CosineEmbeddingLoss
INPUT_MAP(CosineEmbeddingLoss) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(target)}};
ATTR_MAP(CosineEmbeddingLoss) = {{"margin", ATTR_DESC(margin, AnyTraits<float>())},
{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
OUTPUT_MAP(CosineEmbeddingLoss) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(CosineEmbeddingLoss, kNameCosineEmbeddingLoss, ADPT_DESC(CosineEmbeddingLoss))

// Xdivy
INPUT_MAP(Xdivy) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Xdivy) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Xdivy) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Xdivy, kNameXdivy, ADPT_DESC(Xdivy))

// Exp
INPUT_MAP(Exp) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Exp) = EMPTY_ATTR_MAP;
@@ -291,6 +370,12 @@ ATTR_MAP(Equal) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Equal) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Equal, kNameEqual, ADPT_DESC(Equal))

// ApproximateEqual
INPUT_MAP(ApproximateEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(ApproximateEqual) = {{"tolerance", ATTR_DESC(tolerance, AnyTraits<float>())}};
OUTPUT_MAP(ApproximateEqual) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ApproximateEqual, kNameApproximateEqual, ADPT_DESC(ApproximateEqual))

// NotEqual
INPUT_MAP(NotEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(NotEqual) = EMPTY_ATTR_MAP;
@@ -357,6 +442,30 @@ ATTR_MAP(Round) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Round, kNameRound, ADPT_DESC(Round))

// Tan
INPUT_MAP(Tan) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Tan) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Tan) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Tan, kNameTan, ADPT_DESC(Tan))

// Atan
INPUT_MAP(Atan) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Atan) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Atan) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Atan, kNameAtan, ADPT_DESC(Atan))

// AtanGrad
INPUT_MAP(AtanGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(AtanGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(AtanGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(AtanGrad, kNameAtanGrad, ADPT_DESC(AtanGrad))

// Atanh
INPUT_MAP(Atanh) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Atanh) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Atanh) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Atanh, kNameAtanh, ADPT_DESC(Atanh))

// Atan2
INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;


mindspore/ccsrc/transform/graph_ir/op_declare/elewise_calculation_ops_declare.h (+54 -0)

@@ -87,6 +87,24 @@ DECLARE_OP_USE_OUTPUT(MinimumGrad)
DECLARE_OP_ADAPTER(RealDiv)
DECLARE_OP_USE_OUTPUT(RealDiv)

DECLARE_OP_ADAPTER(BitwiseAnd)
DECLARE_OP_USE_OUTPUT(BitwiseAnd)

DECLARE_OP_ADAPTER(BitwiseOr)
DECLARE_OP_USE_OUTPUT(BitwiseOr)

DECLARE_OP_ADAPTER(BitwiseXor)
DECLARE_OP_USE_OUTPUT(BitwiseXor)

DECLARE_OP_ADAPTER(Ceil)
DECLARE_OP_USE_OUTPUT(Ceil)

DECLARE_OP_ADAPTER(CosineEmbeddingLoss)
DECLARE_OP_USE_OUTPUT(CosineEmbeddingLoss)

DECLARE_OP_ADAPTER(Xdivy)
DECLARE_OP_USE_OUTPUT(Xdivy)

DECLARE_OP_ADAPTER(Cast)
DECLARE_OP_USE_INPUT_ATTR(Cast)
DECLARE_OP_USE_OUTPUT(Cast)
@@ -106,6 +124,9 @@ DECLARE_OP_USE_OUTPUT(Pow)
DECLARE_OP_ADAPTER(Equal)
DECLARE_OP_USE_OUTPUT(Equal)

DECLARE_OP_ADAPTER(ApproximateEqual)
DECLARE_OP_USE_OUTPUT(ApproximateEqual)

DECLARE_OP_ADAPTER(NotEqual)
DECLARE_OP_USE_OUTPUT(NotEqual)

@@ -133,6 +154,9 @@ DECLARE_OP_USE_OUTPUT(Add)
DECLARE_OP_ADAPTER(Cos)
DECLARE_OP_USE_OUTPUT(Cos)

DECLARE_OP_ADAPTER(Cosh)
DECLARE_OP_USE_OUTPUT(Cosh)

DECLARE_OP_ADAPTER(Acos)
DECLARE_OP_USE_OUTPUT(Acos)

@@ -145,6 +169,9 @@ DECLARE_OP_USE_OUTPUT(Acosh)
DECLARE_OP_ADAPTER(AcoshGrad)
DECLARE_OP_USE_OUTPUT(AcoshGrad)

DECLARE_OP_ADAPTER(Div)
DECLARE_OP_USE_OUTPUT(Div)

DECLARE_OP_ADAPTER(Floor)
DECLARE_OP_USE_OUTPUT(Floor)

@@ -157,6 +184,21 @@ DECLARE_OP_USE_OUTPUT(FloorMod)
DECLARE_OP_ADAPTER(Sin)
DECLARE_OP_USE_OUTPUT(Sin)

DECLARE_OP_ADAPTER(Sinh)
DECLARE_OP_USE_OUTPUT(Sinh)

DECLARE_OP_ADAPTER(Asin)
DECLARE_OP_USE_OUTPUT(Asin)

DECLARE_OP_ADAPTER(AsinGrad)
DECLARE_OP_USE_OUTPUT(AsinGrad)

DECLARE_OP_ADAPTER(Asinh)
DECLARE_OP_USE_OUTPUT(Asinh)

DECLARE_OP_ADAPTER(AsinhGrad)
DECLARE_OP_USE_OUTPUT(AsinhGrad)

DECLARE_OP_ADAPTER(Exp)
DECLARE_OP_USE_OUTPUT(Exp)

@@ -187,6 +229,18 @@ DECLARE_OP_USE_OUTPUT(Sign)
DECLARE_OP_ADAPTER(Round)
DECLARE_OP_USE_OUTPUT(Round)

DECLARE_OP_ADAPTER(Tan)
DECLARE_OP_USE_OUTPUT(Tan)

DECLARE_OP_ADAPTER(Atan)
DECLARE_OP_USE_OUTPUT(Atan)

DECLARE_OP_ADAPTER(AtanGrad)
DECLARE_OP_USE_OUTPUT(AtanGrad)

DECLARE_OP_ADAPTER(Atanh)
DECLARE_OP_USE_OUTPUT(Atanh)

DECLARE_OP_ADAPTER(Atan2)
DECLARE_OP_USE_OUTPUT(Atan2)



mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.cc (+44 -0)

@@ -61,6 +61,50 @@ REG_ADPT_DESC(ApplyAdamD, kNameApplyAdam, ADPT_DESC(ApplyAdamD))
REG_ADPT_DESC(ApplyAdam, kNameApplyAdam, ADPT_DESC(ApplyAdam))
#endif

// ApplyAdagradD
INPUT_MAP(ApplyAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdagradD) = {{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
REG_ADPT_DESC(ApplyAdagradD, kNameApplyAdagrad, ADPT_DESC(ApplyAdagradD))

// ApplyAdadeltaD
INPUT_MAP(ApplyAdadeltaD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(accum_update)},
{4, INPUT_DESC(lr)}, {5, INPUT_DESC(rho)}, {6, INPUT_DESC(epsilon)},
{7, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdadeltaD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyAdadeltaD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(accum_update)}};
REG_ADPT_DESC(ApplyAdadeltaD, kNameApplyAdadelta, ADPT_DESC(ApplyAdadeltaD))

// ApplyAdaMaxD
INPUT_MAP(ApplyAdaMaxD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
{4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(beta1)},
{7, INPUT_DESC(beta2)}, {8, INPUT_DESC(epsilon)}, {9, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdaMaxD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyAdaMaxD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
REG_ADPT_DESC(ApplyAdaMaxD, kNameApplyAdaMax, ADPT_DESC(ApplyAdaMaxD))

// ApplyGradientDescent
INPUT_MAP(ApplyGradientDescent) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(alpha)}, {3, INPUT_DESC(delta)}};
ATTR_MAP(ApplyGradientDescent) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyGradientDescent) = {{0, OUTPUT_DESC(var)}};
REG_ADPT_DESC(ApplyGradientDescent, kNameApplyGradientDescent, ADPT_DESC(ApplyGradientDescent))

// ApplyPowerSignD
INPUT_MAP(ApplyPowerSignD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(lr)},
{4, INPUT_DESC(logbase)}, {5, INPUT_DESC(sign_decay)}, {6, INPUT_DESC(beta)},
{7, INPUT_DESC(grad)}};
ATTR_MAP(ApplyPowerSignD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyPowerSignD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}};
REG_ADPT_DESC(ApplyPowerSignD, kNameApplyPowerSign, ADPT_DESC(ApplyPowerSignD))

// ApplyProximalGradientDescent
INPUT_MAP(ApplyProximalGradientDescent) = {
{1, INPUT_DESC(var)}, {2, INPUT_DESC(alpha)}, {3, INPUT_DESC(l1)}, {4, INPUT_DESC(l2)}, {5, INPUT_DESC(delta)}};
ATTR_MAP(ApplyProximalGradientDescent) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyProximalGradientDescent) = {{0, OUTPUT_DESC(var)}};
REG_ADPT_DESC(ApplyProximalGradientDescent, kNameApplyProximalGradientDescent, ADPT_DESC(ApplyProximalGradientDescent))

// SGD
INPUT_MAP(SGD) = {{1, INPUT_DESC(parameters)}, {2, INPUT_DESC(gradient)}, {3, INPUT_DESC(learning_rate)},
{4, INPUT_DESC(accum)}, {5, INPUT_DESC(momentum)}, {6, INPUT_DESC(stat)}};


mindspore/ccsrc/transform/graph_ir/op_declare/nn_training_ops_declare.h (+18 -0)

@@ -29,6 +29,24 @@ DECLARE_OP_USE_OUTPUT(ApplyAdam)
DECLARE_OP_ADAPTER(ApplyAdamD)
DECLARE_OP_USE_OUTPUT(ApplyAdamD)

DECLARE_OP_ADAPTER(ApplyAdagradD)
DECLARE_OP_USE_OUTPUT(ApplyAdagradD)

DECLARE_OP_ADAPTER(ApplyAdadeltaD)
DECLARE_OP_USE_OUTPUT(ApplyAdadeltaD)

DECLARE_OP_ADAPTER(ApplyAdaMaxD)
DECLARE_OP_USE_OUTPUT(ApplyAdaMaxD)

DECLARE_OP_ADAPTER(ApplyGradientDescent)
DECLARE_OP_USE_OUTPUT(ApplyGradientDescent)

DECLARE_OP_ADAPTER(ApplyPowerSignD)
DECLARE_OP_USE_OUTPUT(ApplyPowerSignD)

DECLARE_OP_ADAPTER(ApplyProximalGradientDescent)
DECLARE_OP_USE_OUTPUT(ApplyProximalGradientDescent)

DECLARE_OP_ADAPTER(SGD)
DECLARE_OP_USE_OUTPUT(SGD)
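
The optimizer adapters added to nn_training_ops_declare.cc/.h follow the same pattern, just with several inputs, converted boolean attributes, and multiple outputs for the updated state tensors. A condensed sketch using ApplyAdagradD, with the maps copied from the hunk above and the comments added for explanation:

// nn_training_ops_declare.h: declare the adapter for the GE ApplyAdagradD op.
DECLARE_OP_ADAPTER(ApplyAdagradD)
DECLARE_OP_USE_OUTPUT(ApplyAdagradD)

// nn_training_ops_declare.cc: four inputs, two bool attributes, two outputs.
INPUT_MAP(ApplyAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdagradD) = {{"update_slots", ATTR_DESC(update_slots, AnyTraits<bool>())},
                           {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
OUTPUT_MAP(ApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};  // updated var and accum
// Registered under kNameApplyAdagrad = "ApplyAdagrad", so the MindSpore ApplyAdagrad primitive
// is lowered to GE's ApplyAdagradD through this adapter.
REG_ADPT_DESC(ApplyAdagradD, kNameApplyAdagrad, ADPT_DESC(ApplyAdagradD))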



mindspore/ccsrc/transform/graph_ir/op_declare/pad_ops_declare.cc (+6 -0)

@@ -24,6 +24,12 @@ ATTR_MAP(PadD) = {{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::ve
OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(PadD, kNamePadD, ADPT_DESC(PadD))

// BroadcastToD
INPUT_MAP(BroadcastToD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(BroadcastToD) = {{"shape", ATTR_DESC(shape, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(BroadcastToD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BroadcastToD, kNameBroadcastTo, ADPT_DESC(BroadcastToD))

// Diag
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Diag) = EMPTY_ATTR_MAP;


mindspore/ccsrc/transform/graph_ir/op_declare/pad_ops_declare.h (+3 -0)

@@ -26,6 +26,9 @@ namespace mindspore::transform {
DECLARE_OP_ADAPTER(PadD)
DECLARE_OP_USE_OUTPUT(PadD)

DECLARE_OP_ADAPTER(BroadcastToD)
DECLARE_OP_USE_OUTPUT(BroadcastToD)

DECLARE_OP_ADAPTER(Diag)
DECLARE_OP_USE_OUTPUT(Diag)
} // namespace mindspore::transform
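
For adapters whose attributes need a type conversion, ATTR_DESC names the conversion traits; in the pad_ops_declare.cc hunk above, BroadcastToD carries its target shape as an attribute rather than an input, and the two AnyTraits arguments suggest the MindSpore shape value is converted into a vector-of-int64 GE attribute. The relevant lines, copied from the diff with comments added for explanation:

// pad_ops_declare.h
DECLARE_OP_ADAPTER(BroadcastToD)
DECLARE_OP_USE_OUTPUT(BroadcastToD)

// pad_ops_declare.cc: shape is passed as a converted attribute, not a tensor input.
INPUT_MAP(BroadcastToD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(BroadcastToD) = {{"shape", ATTR_DESC(shape, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(BroadcastToD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BroadcastToD, kNameBroadcastTo, ADPT_DESC(BroadcastToD))  // kNameBroadcastTo = "BroadcastTo"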

