Browse Source

!12757 Change all io_format in master

From: @liangzhibo
Reviewed-by: @kingxian
Signed-off-by: @kingxian
pull/12757/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
dfd368a574
5 changed files with 20 additions and 16 deletions
  1. +16
    -1
      mindspore/ccsrc/transform/graph_ir/io_format_map.cc
  2. +3
    -0
      mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
  3. +1
    -7
      mindspore/ops/operations/_grad_ops.py
  4. +0
    -3
      mindspore/ops/operations/_inner_ops.py
  5. +0
    -5
      mindspore/ops/operations/nn_ops.py

+ 16
- 1
mindspore/ccsrc/transform/graph_ir/io_format_map.cc View File

@@ -18,7 +18,22 @@

namespace mindspore {
namespace transform {
std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"MatMul", "ND"}, {"Conv3D", "format"}};
std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"BasicLSTMCell", "ND"},
{"BasicLSTMCellInputGrad", "ND"},
{"BasicLSTMCellCStateGrad", "ND"},
{"Dequant", "ND"},
{"DynamicGRUV2", "ND"},
{"DynamicGRUV2Grad", "ND"},
{"DynamicRNN", "ND"},
{"DynamicRNNGrad", "ND"},
{"MatMul", "ND"},
{"Quant", "ND"},
{"BasicLSTMCellWeightGrad", "HWCN"},
{"ExtractImagePatches", "NCHW"},
{"Conv3D", "format"},
{"Conv3DBackpropFilter", "format"},
{"Conv3DBackpropInput", "format"},
{"Conv3DTranspose", "format"}};
std::unordered_map<std::string, std::string> &IOFormatMap::get() { return io_format_map_; }
} // namespace transform
} // namespace mindspore

+ 3
- 0
mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc View File

@@ -294,6 +294,9 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
MS_LOG(ERROR) << "The anf is not a Primitive.";
return ret;
}
if (prim->HasAttr("io_format")) {
return GetValue<std::string>(prim->GetAttr("io_format"));
}
auto io_format_map = IOFormatMap::get();
auto iter = io_format_map.find(prim->name());
if (iter == io_format_map.end()) {


+ 1
- 7
mindspore/ops/operations/_grad_ops.py View File

@@ -393,7 +393,6 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.add_prim_attr('io_format', self.format)

def __infer__(self, x, doutput, w_size):
w_size_v = w_size['value']
@@ -1367,7 +1366,6 @@ class DynamicRNNGrad(PrimitiveWithInfer):
time_major=True,
forget_bias=0.0):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
@@ -1478,7 +1476,6 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape):
@@ -2063,7 +2060,6 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
def __init__(self, forget_bias, activation):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
# dhy and dcy should be same shape
@@ -2110,10 +2106,9 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):

class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
"""Computes the weight gradients of BasicLSTM."""

@prim_attr_register
def __init__(self):
self.add_prim_attr("io_format", "HWCN")
pass

def infer_shape(self, x_shape, h_shape, dgate_shape):
validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
@@ -2145,7 +2140,6 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
def __init__(self, keep_prob):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, dgate_shape, w_shape):
validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)


+ 0
- 3
mindspore/ops/operations/_inner_ops.py View File

@@ -74,7 +74,6 @@ class ExtractImagePatches(PrimitiveWithInfer):
_check_tuple_or_list("rate", rates, self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)
self.add_prim_attr("io_format", "NCHW")
self.is_ge = context.get_context("enable_ge")

def infer_shape(self, input_x):
@@ -213,7 +212,6 @@ class Quant(PrimitiveWithInfer):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
"round_mode", self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape):
return x_shape
@@ -265,7 +263,6 @@ class Dequant(PrimitiveWithInfer):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
self.add_prim_attr("dtype", mstype.float16)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, deq_scale_shape):
return x_shape


+ 0
- 5
mindspore/ops/operations/nn_ops.py View File

@@ -7222,7 +7222,6 @@ class BasicLSTMCell(PrimitiveWithInfer):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name)
@@ -7373,7 +7372,6 @@ class DynamicRNN(PrimitiveWithInfer):
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
@@ -7533,7 +7531,6 @@ class DynamicGRUV2(PrimitiveWithInfer):
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
@@ -8024,7 +8021,6 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.add_prim_attr('io_format', self.format)

def __infer__(self, w, doutput, x_size):
validator.check_equal_int(len(w['shape']), 5, 'The dimension of weight ', self.name)
@@ -8202,7 +8198,6 @@ class Conv3DTranspose(PrimitiveWithInfer):
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.add_prim_attr('io_format', self.format)

self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
allow_five=True, ret_five=True, greater_zero=False)


Loading…
Cancel
Save