Browse Source

add apply_momentum

tags/v0.3.0-alpha
changzherui 6 years ago
parent
commit
84c029d1ca
2 changed files with 5 additions and 6 deletions
  1. +1
    -1
      mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
  2. +4
    -5
      mindspore/ops/_op_impl/tbe/batchnorm.py

+ 1
- 1
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc View File

@@ -32,7 +32,7 @@ namespace tbe {
static std::map<string, string> tbe_func_adapter_map = {
{"softmax", "softmax_v2"},
{"log_softmax", "log_softmax_v2"},
{"applymomentum", "applymomentum_d"},
{"apply_momentum", "apply_momentum_d"},
{"re_lu6", "relu6"},
{"re_lu6_grad", "relu6_grad"},
{"re_lu", "relu"},


+ 4
- 5
mindspore/ops/_op_impl/tbe/batchnorm.py View File

@@ -36,19 +36,18 @@ batch_norm_op_info = TBERegOp("BatchNorm") \
.output(2, "batch_variance", False, "required", "all") \
.output(3, "reserve_space_1", False, "optional", "all") \
.output(4, "reserve_space_2", False, "optional", "all") \
.output(5, "reserve_space_3", False, "optional", "all") \
.dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()




Loading…
Cancel
Save