diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter.cc
index b72230e2c7..0bb0995be6 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_adapter.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter.cc
@@ -276,12 +276,8 @@ OutHandler OpAdapterImpl::getNormalOutput(const OperatorPtr &op, int index) {
 }
 
 Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
-                                             const TypePtr &type) {
+                                             const TypePtr &type, const std::string &format) {
   MS_EXCEPTION_IF_NULL(type);
-  std::string format = "NCHW";
-  if (op->GetOpType() == kExtractImagePatchesOpName) {
-    format = "NHWC";
-  }
 
   auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
   if (desc == nullptr) {
@@ -340,7 +336,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::Sh
 }
 
 Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
-                                            const TypePtr &type) {
+                                            const TypePtr &type, const std::string &format) {
   auto tuple_shp = dyn_cast<abstract::TupleShape>(shp);
   MS_EXCEPTION_IF_NULL(tuple_shp);
 
@@ -361,10 +357,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
     MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
     return FAILED;
   }
-  std::string format = "NCHW";
-  if (op->GetOpType() == kTopKOpName) {
-    format = "NHWC";
-  }
+
   for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
     auto tuple_type = dyn_cast<Tuple>(type);
     MS_EXCEPTION_IF_NULL(tuple_type);
@@ -389,7 +382,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
   return SUCCESS;
 }
 
-std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node) {
+std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
   MS_EXCEPTION_IF_NULL(node);
   TypeId me_type = node->Type()->type_id();
   if (kObjectTypeTensorType == me_type) {
@@ -405,7 +398,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
     shape = shape_ptr->shape();
   }
 
-  auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
+  auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
   if (desc == nullptr) {
     MS_LOG(ERROR) << "Update output descriptor failed!";
     return nullptr;
@@ -413,7 +406,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
   return desc;
 }
 
-void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
+void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format) {
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is nullptr";
     return;
@@ -424,19 +417,18 @@ void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNode
   for (size_t i = 1; i < inputs.size(); ++i) {
     auto it = input_map_.find(i);
     if (it != input_map_.end()) {
-      auto desc = CreateNodeDesc(inputs[i]);
+      auto desc = CreateNodeDesc(inputs[i], format);
       if (desc == nullptr) {
         continue;
       }
-      if (op->GetOpType() == kExtractImagePatchesOpName) {
-        desc->SetFormat(ge::Format::FORMAT_NHWC);
-      }
+
       it->second.update_input_desc(op, *desc);
     }
   }
 }
 
-void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
+void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node,
+                                            const std::string format) {
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is nullptr";
     return;
@@ -452,7 +444,7 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
   auto inputs = node->cast<CNodePtr>()->inputs();
   for (size_t i = 1; i < inputs.size(); ++i) {
     if (input_map.find(i) != input_map.end()) {
-      auto desc = CreateNodeDesc(inputs[i]);
+      auto desc = CreateNodeDesc(inputs[i], format);
       if (desc == nullptr) {
         continue;
       }
@@ -464,11 +456,12 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
 void OpAdapterImpl::updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(op);
   MS_EXCEPTION_IF_NULL(node);
+  std::string format = GetOpIOFormat(node);
   if (IsCustomOp(op)) {
     auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
-    UpdateCustomOpInputDesc(cus_op, node);
+    UpdateCustomOpInputDesc(cus_op, node, format);
   } else {
-    UpdateNormalOpInputDesc(op, node);
+    UpdateNormalOpInputDesc(op, node, format);
   }
 }
 
@@ -483,13 +476,14 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base
   auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
   auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
+  std::string format = GetOpIOFormat(node);
 
   if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) {
-    if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) {
+    if (UpdateSingleOutputDesc(op, shp, type, format) != SUCCESS) {
       return;
     }
   } else if (nullptr != dyn_cast<abstract::TupleShape>(shp)) {
-    if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) {
+    if (UpdateMultiOutputDesc(op, shp, type, format) != SUCCESS) {
       return;
     }
   } else {
diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter.h b/mindspore/ccsrc/transform/graph_ir/op_adapter.h
index b02fe1886c..37595523b3 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_adapter.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter.h
@@ -75,14 +75,16 @@ class OpAdapterImpl {
   OutHandler getOutput(const OperatorPtr &op, int index);
   OutHandler getCustomOutput(const OperatorPtr &op, int index);
   OutHandler getNormalOutput(const OperatorPtr &op, int index);
-  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
+  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                                const std::string &format);
   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op);
   std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
                                                  const std::string &format);
-  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
-  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node);
-  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
-  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node);
+  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                               const std::string &format);
+  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format);
+  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format);
+  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format);
   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
   void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
                         const AnfNodePtr &node);
@@ -226,8 +228,9 @@ class OpAdapter : public BaseOpAdapter {
   OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); }
 
-  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
-    return impl_->UpdateSingleOutputDesc(op, shp, type);
+  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                                const std::string &format) {
+    return impl_->UpdateSingleOutputDesc(op, shp, type, format);
   }
 
   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); }
 
@@ -237,18 +240,21 @@ class OpAdapter : public BaseOpAdapter {
     return impl_->CreateOutputDesc(shape_ptr, type, format);
   }
 
-  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
-    return impl_->UpdateMultiOutputDesc(op, shp, type);
+  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                               const std::string &format) {
+    return impl_->UpdateMultiOutputDesc(op, shp, type, format);
   }
 
-  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node) { return impl_->CreateNodeDesc(node); }
+  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
+    return impl_->CreateNodeDesc(node, format);
+  }
 
-  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) {
-    return impl_->UpdateNormalOpInputDesc(op, node);
+  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) {
+    return impl_->UpdateNormalOpInputDesc(op, node, format);
   }
 
-  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
-    return impl_->UpdateCustomOpInputDesc(op, node);
+  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) {
+    return impl_->UpdateCustomOpInputDesc(op, node, format);
   }
 
   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); }
diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
index 78f1f263de..04cb4f3129 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
@@ -247,7 +247,7 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
     return false;
   }
   if (node->inputs().empty()) {
-    MS_LOG(EXCEPTION) << "length of node inputs is empty";
+    MS_LOG(EXCEPTION) << "Length of node inputs is empty";
   }
   MS_EXCEPTION_IF_NULL(node->inputs()[0]);
   if (!node->inputs()[0]->isa<ValueNode>()) {
@@ -260,5 +260,37 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
   return IsCustomPrim(cus_prim);
 }
 
+
+std::string GetOpIOFormat(const AnfNodePtr &anf) {
+  std::string ret;
+  if (anf == nullptr) {
+    MS_LOG(ERROR) << "The anf is nullptr.";
+    return ret;
+  }
+  auto node = anf->cast<CNodePtr>();
+  if (node == nullptr) {
+    MS_LOG(ERROR) << "The anf is not a cnode.";
+    return ret;
+  }
+  if (node->inputs().empty()) {
+    MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
+  }
+  MS_EXCEPTION_IF_NULL(node->inputs()[0]);
+  if (!node->inputs()[0]->isa<ValueNode>()) {
+    MS_LOG(ERROR) << "The anf is not a value node.";
+    return ret;
+  }
+  auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "The anf is not a Primitive.";
+    return ret;
+  }
+  ValuePtr format = prim->GetAttr("io_format");
+  if (format == nullptr) {
+    return "NCHW";
+  }
+  ret = GetValue<std::string>(format);
+  return ret;
+}
 }  // namespace transform
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h
index d80aa3b5b3..e43e07b074 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h
@@ -61,6 +61,7 @@
 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>);
 bool IsCustomPrim(const PrimitivePtr &prim);
 bool IsCustomCNode(const AnfNodePtr &node);
+std::string GetOpIOFormat(const AnfNodePtr &node);
 }  // namespace transform
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc
index a6275f8df5..1b6c433664 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare/rnn_declare.cc
@@ -25,7 +25,7 @@ ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>()
                            {"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())},
                            {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
 OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)},
-                             {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}};
+                             {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {6, OUTPUT_DESC(tanhct)}};
 REG_ADPT_DESC(BasicLSTMCell, kNameBasicLSTMCell, ADPT_DESC(BasicLSTMCell))
 
 // BasicLSTMCellInputGrad
@@ -35,7 +35,7 @@ OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht
 REG_ADPT_DESC(BasicLSTMCellInputGrad, kNameBasicLSTMCellInputGrad, ADPT_DESC(BasicLSTMCellInputGrad))
 
 // BasicLSTMCellWeightGrad
-INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}};
+INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(dgate)}};
 ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}};
 REG_ADPT_DESC(BasicLSTMCellWeightGrad, kNameBasicLSTMCellWeightGrad, ADPT_DESC(BasicLSTMCellWeightGrad))
diff --git a/mindspore/ccsrc/transform/graph_ir/util.cc b/mindspore/ccsrc/transform/graph_ir/util.cc
index 4c653b3c80..e00db248e1 100644
--- a/mindspore/ccsrc/transform/graph_ir/util.cc
+++ b/mindspore/ccsrc/transform/graph_ir/util.cc
@@ -87,7 +87,10 @@ GeFormat TransformUtil::ConvertFormat(const string &format) {
     return GeFormat::FORMAT_NHWC;
   } else if (format == kOpFormat_HWCN) {
     return GeFormat::FORMAT_HWCN;
+  } else if (format == kOpFormat_ND) {
+    return GeFormat::FORMAT_ND;
   } else {
+    MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
     return GeFormat::FORMAT_ND;
   }
 }
@@ -113,8 +116,7 @@ std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const std::vector<
   // convert me format to ge format
   GeFormat ge_format = ConvertFormat(format);
   if (ge_format == GeFormat::FORMAT_ND) {
-    MS_LOG(ERROR) << "Undefined data format : " << static_cast<int>(ge_format);
-    return nullptr;
+    MS_LOG(INFO) << "Set ND data format";
   }
   // convert me datatype to ge datatype
   GeDataType data_type = ConvertDataType(me_type);
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index e440e0a0b8..168d76ff55 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1537,6 +1537,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
     def __init__(self, forget_bias, activation):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
         # dhy and dcy should be same shape
@@ -1586,7 +1587,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self):
-        pass
+        self.add_prim_attr("io_format", "HWCN")
 
     def infer_shape(self, x_shape, h_shape, dgate_shape):
         validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
@@ -1595,8 +1596,10 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
         validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
         validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
         validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
-        dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
-        db_shape = (dgate_shape[1], 1, 1, 1)
+        input_size = x_shape[1]
+        hidden_size = h_shape[1]
+        dw_shape = (input_size + hidden_size, 4 * hidden_size)
+        db_shape = (4 * hidden_size,)
         return (dw_shape, db_shape)
 
     def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
@@ -1616,13 +1619,17 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
     def __init__(self, keep_prob):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
         self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, dgate_shape, w_shape):
         validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
-        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
-        dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
-        dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        batch_size = dgate_shape[0]
+        hidden_size = dgate_shape[1] // 4
+        input_size = w_shape[0] - hidden_size
+        dxt_shape = (batch_size, input_size)
+        dht_shape = (batch_size, hidden_size)
         return (dxt_shape, dht_shape)
 
     def infer_dtype(self, dgate_dtype, w_dtype):
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 8d90c487e8..02b2319d6d 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -198,6 +198,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
         _check_tuple_or_list("rate", rates, self.name)
         self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
         self.add_prim_attr("padding", self.padding)
+        self.add_prim_attr("io_format", "NHWC")
 
     def infer_shape(self, input_x):
         """infer shape"""
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 57ec4dbe82..eb656c73d7 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -5335,35 +5335,41 @@ class BasicLSTMCell(PrimitiveWithInfer):
         forget_bias (float): Add forget bias to forget gate biases in order to decrease former scale. Default to 1.0.
         state_is_tuple (bool): If true, state is tensor tuple, containing h and c; If false, one tensor,
           need split first. Default to True.
-        activation (str): Activation. Default to "tanh".
+        activation (str): Activation. Default to "tanh". Only "tanh" is currently supported.
 
     Inputs:
         - **x** (Tensor) - Current words. Tensor of shape (`batch_size`, `input_size`).
+          The data type must be float16 or float32.
        - **h** (Tensor) - Hidden state last moment. Tensor of shape (`batch_size`, `hidden_size`).
+          The data type must be float16 or float32.
        - **c** (Tensor) - Cell state last moment. Tensor of shape (`batch_size`, `hidden_size`).
-        - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
-        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
+          The data type must be float16 or float32.
+        - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
+          The data type must be float16 or float32.
+        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
+          The data type must be the same as `c`.
 
     Outputs:
         - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
-        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same data type as input `c`.
+        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`). With data type of float16.
        - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same data type as input `c`.
        - **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same data type as input `c`.
        - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same data type as input `c`.
        - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same data type as input `c`.
        - **tanhct** (Tensor) - Forward :math:`tanh c_t` cache at moment `t`.
-          Tensor of shape (`batch_size`, `hidden_size`).
+          Tensor of shape (`batch_size`, `hidden_size`). Has the same data type as input `c`.
 
     Examples:
-        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
-        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]],
-        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
-
-        >>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
-        >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
+        >>> x = Tensor(np.random.rand(1, 32).astype(np.float16))
+        >>> h = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> c = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> w = Tensor(np.random.rand(96, 256).astype(np.float16))
+        >>> b = Tensor(np.random.rand(256, ).astype(np.float16))
         >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
        >>> lstm(x, h, c, w, b)
     """
@@ -5375,42 +5381,38 @@ class BasicLSTMCell(PrimitiveWithInfer):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
         self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
-        # (batch_size, input_size)
-        validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
-
-        # h and c should be same shape
-        validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
-        validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
-        validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
-        validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
-        validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
-        validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
+        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
+        validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[1]", w_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[0]", w_shape[0], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
         validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
         ct_shape = c_shape
-        ht_shape = h_shape
-        it_shape = h_shape
-        jt_shape = h_shape
-        ft_shape = h_shape
-        ot_shape = h_shape
-        tanhct_shape = h_shape
+        ht_shape = c_shape
+        it_shape = c_shape
+        jt_shape = c_shape
+        ft_shape = c_shape
+        ot_shape = c_shape
+        tanhct_shape = c_shape
 
         return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
 
     def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
-        validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
-        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
-        return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
+        validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name)
+
+        args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
 
 
 class InTopK(PrimitiveWithInfer):
diff --git a/tests/ut/cpp/transform/convert_test.cc b/tests/ut/cpp/transform/convert_test.cc
index 6902f7d90d..bcdf33c56d 100644
--- a/tests/ut/cpp/transform/convert_test.cc
+++ b/tests/ut/cpp/transform/convert_test.cc
@@ -735,7 +735,7 @@ TEST_F(TestConvert, TestConvertTensorError) {
   std::vector<int> dims2{2, 3, 4};
   auto type_id_2 = kNumberTypeFloat32;
   auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
-  ASSERT_EQ(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
+  ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
 }
 
 TEST_F(TestConvert, TestUtilsConvertDataType) {
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 17af407b1f..817f7ff58f 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -701,6 +701,16 @@ class ParallelConcatNet(nn.Cell):
         return self.parallel_concat((x1, x2))
 
 
+class BasicLSTMCellNet(nn.Cell):
+    """ BasicLSTMCellNet definition """
+
+    def __init__(self):
+        super(BasicLSTMCellNet, self).__init__()
+        self.lstm = P.BasicLSTMCell()
+
+    def construct(self, x, h, c, w, b):
+        return self.lstm(x, h, c, w, b)
+
 class EditDistance(nn.Cell):
     def __init__(self, hypothesis_shape, truth_shape, normalize=True):
         super(EditDistance, self).__init__()
@@ -1402,11 +1412,6 @@
         'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
-    ('BasicLSTMCell', {
-        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
-        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]],
-        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
-        'skip': []}),
     ('TopK', {
         'block': P.TopK(),
         'desc_const': [5],
@@ -2346,6 +2351,18 @@
         'block': P.PopulationCount(),
         'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))],
         'skip': ['backward']}),
+    ('BasicLSTMCellNet', {
+        'block': BasicLSTMCellNet(),
+        'desc_inputs': [Tensor(np.random.rand(1, 32).astype(np.float16)),
+                        Tensor(np.random.rand(1, 64).astype(np.float16)),
+                        Tensor(np.random.rand(1, 64).astype(np.float16)),
+                        Tensor(np.random.rand(96, 256).astype(np.float16)),
+                        Tensor(np.random.rand(256, ).astype(np.float16))],
+        'desc_bprop': [Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16))]}),
 ]
 
 test_case_quant_ops = [
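The new attribute ties the pieces of this patch together: a primitive that needs a non-default layout registers "io_format" via add_prim_attr in its __init__, GetOpIOFormat reads that attribute from the op's Primitive value node, and updateInputDesc/updateOutputDesc pass the result down to TransformUtil::GetGeTensorDesc, falling back to "NCHW" when the attribute is absent. Below is a minimal Python sketch of that lookup convention, not part of the patch; it assumes Primitive.attrs exposes attributes registered with add_prim_attr, and lookup_io_format is a hypothetical helper used only for illustration.

from mindspore.ops import operations as P


def lookup_io_format(prim):
    """Mirror the fallback logic of transform::GetOpIOFormat on the Python side."""
    # Attributes registered with add_prim_attr("io_format", ...) appear in Primitive.attrs.
    fmt = prim.attrs.get("io_format")
    return fmt if fmt is not None else "NCHW"


# After this patch, BasicLSTMCell declares "ND"; TopK declares nothing and keeps the NCHW default.
print(lookup_io_format(P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')))
print(lookup_io_format(P.TopK()))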