| @@ -276,12 +276,8 @@ OutHandler OpAdapterImpl::getNormalOutput(const OperatorPtr &op, int index) { | |||
| } | |||
| Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, | |||
| const TypePtr &type) { | |||
| const TypePtr &type, const std::string &format) { | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| std::string format = "NCHW"; | |||
| if (op->GetOpType() == kExtractImagePatchesOpName) { | |||
| format = "NHWC"; | |||
| } | |||
| auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format); | |||
| if (desc == nullptr) { | |||
| @@ -340,7 +336,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::Sh | |||
| } | |||
| Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, | |||
| const TypePtr &type) { | |||
| const TypePtr &type, const std::string &format) { | |||
| auto tuple_shp = dyn_cast<abstract::TupleShape>(shp); | |||
| MS_EXCEPTION_IF_NULL(tuple_shp); | |||
| @@ -361,10 +357,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac | |||
| MS_LOG(ERROR) << "output_map is not equal tuple_shape size"; | |||
| return FAILED; | |||
| } | |||
| std::string format = "NCHW"; | |||
| if (op->GetOpType() == kTopKOpName) { | |||
| format = "NHWC"; | |||
| } | |||
| for (size_t i = 0; i < tuple_shp->shape().size(); ++i) { | |||
| auto tuple_type = dyn_cast<Tuple>(type); | |||
| MS_EXCEPTION_IF_NULL(tuple_type); | |||
| @@ -389,7 +382,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node) { | |||
| std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node, const std::string &format) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| TypeId me_type = node->Type()->type_id(); | |||
| if (kObjectTypeTensorType == me_type) { | |||
| @@ -405,7 +398,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no | |||
| shape = shape_ptr->shape(); | |||
| } | |||
| auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); | |||
| auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format); | |||
| if (desc == nullptr) { | |||
| MS_LOG(ERROR) << "Update output descriptor failed!"; | |||
| return nullptr; | |||
| @@ -413,7 +406,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no | |||
| return desc; | |||
| } | |||
| void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { | |||
| void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format) { | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is nullptr"; | |||
| return; | |||
| @@ -424,19 +417,18 @@ void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNode | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto it = input_map_.find(i); | |||
| if (it != input_map_.end()) { | |||
| auto desc = CreateNodeDesc(inputs[i]); | |||
| auto desc = CreateNodeDesc(inputs[i], format); | |||
| if (desc == nullptr) { | |||
| continue; | |||
| } | |||
| if (op->GetOpType() == kExtractImagePatchesOpName) { | |||
| desc->SetFormat(ge::Format::FORMAT_NHWC); | |||
| } | |||
| it->second.update_input_desc(op, *desc); | |||
| } | |||
| } | |||
| } | |||
| void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { | |||
| void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, | |||
| const std::string format) { | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is nullptr"; | |||
| return; | |||
| @@ -452,7 +444,7 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN | |||
| auto inputs = node->cast<CNodePtr>()->inputs(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| if (input_map.find(i) != input_map.end()) { | |||
| auto desc = CreateNodeDesc(inputs[i]); | |||
| auto desc = CreateNodeDesc(inputs[i], format); | |||
| if (desc == nullptr) { | |||
| continue; | |||
| } | |||
| @@ -464,11 +456,12 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN | |||
| void OpAdapterImpl::updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(op); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| std::string format = GetOpIOFormat(node); | |||
| if (IsCustomOp(op)) { | |||
| auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op); | |||
| UpdateCustomOpInputDesc(cus_op, node); | |||
| UpdateCustomOpInputDesc(cus_op, node, format); | |||
| } else { | |||
| UpdateNormalOpInputDesc(op, node); | |||
| UpdateNormalOpInputDesc(op, node, format); | |||
| } | |||
| } | |||
| @@ -483,13 +476,14 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base | |||
| auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp); | |||
| auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp); | |||
| std::string format = GetOpIOFormat(node); | |||
| if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) { | |||
| if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) { | |||
| if (UpdateSingleOutputDesc(op, shp, type, format) != SUCCESS) { | |||
| return; | |||
| } | |||
| } else if (nullptr != dyn_cast<abstract::TupleShape>(shp)) { | |||
| if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) { | |||
| if (UpdateMultiOutputDesc(op, shp, type, format) != SUCCESS) { | |||
| return; | |||
| } | |||
| } else { | |||
| @@ -75,14 +75,16 @@ class OpAdapterImpl { | |||
| OutHandler getOutput(const OperatorPtr &op, int index); | |||
| OutHandler getCustomOutput(const OperatorPtr &op, int index); | |||
| OutHandler getNormalOutput(const OperatorPtr &op, int index); | |||
| Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type); | |||
| Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, | |||
| const std::string &format); | |||
| size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op); | |||
| std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, | |||
| const std::string &format); | |||
| Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type); | |||
| std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node); | |||
| void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node); | |||
| void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node); | |||
| Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, | |||
| const std::string &format); | |||
| std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format); | |||
| void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format); | |||
| void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format); | |||
| void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node); | |||
| void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, | |||
| const AnfNodePtr &node); | |||
| @@ -226,8 +228,9 @@ class OpAdapter : public BaseOpAdapter { | |||
| OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); } | |||
| Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { | |||
| return impl_->UpdateSingleOutputDesc(op, shp, type); | |||
| Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, | |||
| const std::string &format) { | |||
| return impl_->UpdateSingleOutputDesc(op, shp, type, format); | |||
| } | |||
| size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); } | |||
| @@ -237,18 +240,21 @@ class OpAdapter : public BaseOpAdapter { | |||
| return impl_->CreateOutputDesc(shape_ptr, type, format); | |||
| } | |||
| Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { | |||
| return impl_->UpdateMultiOutputDesc(op, shp, type); | |||
| Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, | |||
| const std::string &format) { | |||
| return impl_->UpdateMultiOutputDesc(op, shp, type, format); | |||
| } | |||
| std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node) { return impl_->CreateNodeDesc(node); } | |||
| std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) { | |||
| return impl_->CreateNodeDesc(node, format); | |||
| } | |||
| void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { | |||
| return impl_->UpdateNormalOpInputDesc(op, node); | |||
| void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) { | |||
| return impl_->UpdateNormalOpInputDesc(op, node, format); | |||
| } | |||
| void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { | |||
| return impl_->UpdateCustomOpInputDesc(op, node); | |||
| void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) { | |||
| return impl_->UpdateCustomOpInputDesc(op, node, format); | |||
| } | |||
| void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); } | |||
| @@ -247,7 +247,7 @@ bool IsCustomCNode(const AnfNodePtr &anf) { | |||
| return false; | |||
| } | |||
| if (node->inputs().empty()) { | |||
| MS_LOG(EXCEPTION) << "length of node inputs is empty"; | |||
| MS_LOG(EXCEPTION) << "Length of node inputs is empty"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(node->inputs()[0]); | |||
| if (!node->inputs()[0]->isa<ValueNode>()) { | |||
| @@ -260,5 +260,37 @@ bool IsCustomCNode(const AnfNodePtr &anf) { | |||
| return IsCustomPrim(cus_prim); | |||
| } | |||
| std::string GetOpIOFormat(const AnfNodePtr &anf) { | |||
| std::string ret; | |||
| if (anf == nullptr) { | |||
| MS_LOG(ERROR) << "The anf is nullptr"; | |||
| return ret; | |||
| } | |||
| auto node = anf->cast<CNodePtr>(); | |||
| if (node == nullptr) { | |||
| MS_LOG(ERROR) << "The anf is not a cnode."; | |||
| return ret; | |||
| } | |||
| if (node->inputs().empty()) { | |||
| MS_LOG(EXCEPTION) << "Length of node inputs is empty."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(node->inputs()[0]); | |||
| if (!node->inputs()[0]->isa<ValueNode>()) { | |||
| MS_LOG(ERROR) << "The anf is not a value node."; | |||
| return ret; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]); | |||
| if (prim == nullptr) { | |||
| MS_LOG(ERROR) << "The anf is not a Primitive."; | |||
| return ret; | |||
| } | |||
| ValuePtr format = prim->GetAttr("io_format"); | |||
| if (format == nullptr) { | |||
| return "NCHW"; | |||
| } | |||
| ret = GetValue<std::string>(format); | |||
| return ret; | |||
| } | |||
| } // namespace transform | |||
| } // namespace mindspore | |||
| @@ -61,6 +61,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>); | |||
| bool IsCustomPrim(const PrimitivePtr &prim); | |||
| bool IsCustomCNode(const AnfNodePtr &node); | |||
| std::string GetOpIOFormat(const AnfNodePtr &node); | |||
| } // namespace transform | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_ | |||
| @@ -25,7 +25,7 @@ ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>() | |||
| {"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())}, | |||
| {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}}; | |||
| OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)}, | |||
| {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}}; | |||
| {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {6, OUTPUT_DESC(tanhct)}}; | |||
| REG_ADPT_DESC(BasicLSTMCell, kNameBasicLSTMCell, ADPT_DESC(BasicLSTMCell)) | |||
| // BasicLSTMCellInputGrad | |||
| @@ -35,7 +35,7 @@ OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht | |||
| REG_ADPT_DESC(BasicLSTMCellInputGrad, kNameBasicLSTMCellInputGrad, ADPT_DESC(BasicLSTMCellInputGrad)) | |||
| // BasicLSTMCellWeightGrad | |||
| INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}}; | |||
| INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(dgate)}}; | |||
| ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}}; | |||
| REG_ADPT_DESC(BasicLSTMCellWeightGrad, kNameBasicLSTMCellWeightGrad, ADPT_DESC(BasicLSTMCellWeightGrad)) | |||
| @@ -87,7 +87,10 @@ GeFormat TransformUtil::ConvertFormat(const string &format) { | |||
| return GeFormat::FORMAT_NHWC; | |||
| } else if (format == kOpFormat_HWCN) { | |||
| return GeFormat::FORMAT_HWCN; | |||
| } else if (format == kOpFormat_ND) { | |||
| return GeFormat::FORMAT_ND; | |||
| } else { | |||
| MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead."; | |||
| return GeFormat::FORMAT_ND; | |||
| } | |||
| } | |||
| @@ -113,8 +116,7 @@ std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const std::vector<i | |||
| // convert me format to ge format | |||
| GeFormat ge_format = ConvertFormat(format); | |||
| if (ge_format == GeFormat::FORMAT_ND) { | |||
| MS_LOG(ERROR) << "undefined data format : " << static_cast<int>(ge_format); | |||
| return nullptr; | |||
| MS_LOG(INFO) << "Set ND data format"; | |||
| } | |||
| // convert me datatype to ge datatype | |||
| GeDataType data_type = ConvertDataType(me_type); | |||
| @@ -1537,6 +1537,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer): | |||
| def __init__(self, forget_bias, activation): | |||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | |||
| self.activation = validator.check_string("activation", activation, ['tanh'], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape): | |||
| # dhy and dcy should be same shape | |||
| @@ -1586,7 +1587,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| pass | |||
| self.add_prim_attr("io_format", "HWCN") | |||
| def infer_shape(self, x_shape, h_shape, dgate_shape): | |||
| validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name) | |||
| @@ -1595,8 +1596,10 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer): | |||
| validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name) | |||
| validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) | |||
| validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) | |||
| dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1) | |||
| db_shape = (dgate_shape[1], 1, 1, 1) | |||
| input_size = x_shape[1] | |||
| hidden_size = h_shape[1] | |||
| dw_shape = (input_size + hidden_size, 4 * hidden_size) | |||
| db_shape = (4 * hidden_size,) | |||
| return (dw_shape, db_shape) | |||
| def infer_dtype(self, x_dtype, h_dtype, dgate_dtype): | |||
| @@ -1616,13 +1619,17 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer): | |||
| def __init__(self, keep_prob): | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, dgate_shape, w_shape): | |||
| validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name) | |||
| validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name) | |||
| dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4) | |||
| dht_shape = (dgate_shape[0], dgate_shape[1] // 4) | |||
| validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) | |||
| validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) | |||
| batch_size = dgate_shape[0] | |||
| hidden_size = dgate_shape[1] // 4 | |||
| input_size = w_shape[0] - hidden_size | |||
| dxt_shape = (batch_size, input_size) | |||
| dht_shape = (batch_size, hidden_size) | |||
| return (dxt_shape, dht_shape) | |||
| def infer_dtype(self, dgate_dtype, w_dtype): | |||
| @@ -198,6 +198,7 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| _check_tuple_or_list("rate", rates, self.name) | |||
| self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) | |||
| self.add_prim_attr("padding", self.padding) | |||
| self.add_prim_attr("io_format", "NHWC") | |||
| def infer_shape(self, input_x): | |||
| """infer shape""" | |||
| @@ -5335,35 +5335,41 @@ class BasicLSTMCell(PrimitiveWithInfer): | |||
| forget_bias (float): Add forget bias to forget gate biases in order to decrease former scale. Default to 1.0. | |||
| state_is_tuple (bool): If true, state is tensor tuple, containing h and c; If false, one tensor, | |||
| need split first. Default to True. | |||
| activation (str): Activation. Default to "tanh". | |||
| activation (str): Activation. Default to "tanh". Only "tanh" is currently supported. | |||
| Inputs: | |||
| - **x** (Tensor) - Current words. Tensor of shape (`batch_size`, `input_size`). | |||
| The data type must be float16 or float32. | |||
| - **h** (Tensor) - Hidden state last moment. Tensor of shape (`batch_size`, `hidden_size`). | |||
| The data type must be float16 or float32. | |||
| - **c** (Tensor) - Cell state last moment. Tensor of shape (`batch_size`, `hidden_size`). | |||
| - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1). | |||
| - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1). | |||
| The data type must be float16 or float32. | |||
| - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`). | |||
| The data type must be float16 or float32. | |||
| - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`). | |||
| The data type must be same as `c`. | |||
| Outputs: | |||
| - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`). | |||
| - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`). | |||
| Has the same type with input `c`. | |||
| - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`). With data type of float16. | |||
| - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`). | |||
| Has the same type with input `c`. | |||
| - **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`). | |||
| Has the same type with input `c`. | |||
| - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`). | |||
| Has the same type with input `c`. | |||
| - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`). | |||
| Has the same type with input `c`. | |||
| - **tanhct** (Tensor) - Forward :math:`tanh c_t` cache at moment `t`. | |||
| Tensor of shape (`batch_size`, `hidden_size`). | |||
| Tensor of shape (`batch_size`, `hidden_size`). Has the same type with input `c`. | |||
| Examples: | |||
| 'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'), | |||
| 'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]], | |||
| 'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]], | |||
| >>> x = Tensor(np.random.rand(128, 128).astype(np.float16)) | |||
| >>> h = Tensor(np.random.rand(128, 128).astype(np.float16)) | |||
| >>> c = Tensor(np.random.rand(128, 128).astype(np.float16)) | |||
| >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16)) | |||
| >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16)) | |||
| >>> x = Tensor(np.random.rand(1, 32).astype(np.float16)) | |||
| >>> h = Tensor(np.random.rand(1, 64).astype(np.float16)) | |||
| >>> c = Tensor(np.random.rand(1, 64).astype(np.float16)) | |||
| >>> w = Tensor(np.random.rand(96, 256).astype(np.float16)) | |||
| >>> b = Tensor(np.random.rand(256, ).astype(np.float16)) | |||
| >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh') | |||
| >>> lstm(x, h, c, w, b) | |||
| """ | |||
| @@ -5375,42 +5381,38 @@ class BasicLSTMCell(PrimitiveWithInfer): | |||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | |||
| self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name) | |||
| self.activation = validator.check_string("activation", activation, ['tanh'], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape): | |||
| # (batch_size, input_size) | |||
| validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name) | |||
| # h and c should be same shape | |||
| validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name) | |||
| validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name) | |||
| validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name) | |||
| validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name) | |||
| validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name) | |||
| validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) | |||
| validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name) | |||
| validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name) | |||
| validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) | |||
| validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) | |||
| validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name) | |||
| validator.check("w_shape[1]", w_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) | |||
| validator.check("w_shape[0]", w_shape[0], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name) | |||
| validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) | |||
| ct_shape = c_shape | |||
| ht_shape = h_shape | |||
| it_shape = h_shape | |||
| jt_shape = h_shape | |||
| ft_shape = h_shape | |||
| ot_shape = h_shape | |||
| tanhct_shape = h_shape | |||
| ht_shape = c_shape | |||
| it_shape = c_shape | |||
| jt_shape = c_shape | |||
| ft_shape = c_shape | |||
| ot_shape = c_shape | |||
| tanhct_shape = c_shape | |||
| return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape) | |||
| def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype): | |||
| validator.check_subclass("x", x_dtype, [mstype.tensor], self.name) | |||
| validator.check_subclass("h", h_dtype, [mstype.tensor], self.name) | |||
| validator.check_subclass("c", c_dtype, [mstype.tensor], self.name) | |||
| validator.check_subclass("w", w_dtype, [mstype.tensor], self.name) | |||
| validator.check_subclass("b", b_dtype, [mstype.tensor], self.name) | |||
| validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype) | |||
| validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| args = {"c_dtype": c_dtype, "b_dtype": b_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||
| return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype) | |||
| class InTopK(PrimitiveWithInfer): | |||
| @@ -735,7 +735,7 @@ TEST_F(TestConvert, TestConvertTensorError) { | |||
| std::vector<int> dims2{2, 3, 4}; | |||
| auto type_id_2 = kNumberTypeFloat32; | |||
| auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2); | |||
| ASSERT_EQ(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr); | |||
| ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr); | |||
| } | |||
| TEST_F(TestConvert, TestUtilsConvertDataType) { | |||
| @@ -701,6 +701,16 @@ class ParallelConcatNet(nn.Cell): | |||
| return self.parallel_concat((x1, x2)) | |||
| class BasicLSTMCellNet(nn.Cell): | |||
| """ BasicLSTMCellNet definition """ | |||
| def __init__(self): | |||
| super(BasicLSTMCellNet, self).__init__() | |||
| self.lstm = P.BasicLSTMCell() | |||
| def construct(self, x, h, c, w, b): | |||
| return self.lstm(x, h, c, w, b) | |||
| class EditDistance(nn.Cell): | |||
| def __init__(self, hypothesis_shape, truth_shape, normalize=True): | |||
| super(EditDistance, self).__init__() | |||
| @@ -1402,11 +1412,6 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], | |||
| 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], | |||
| 'skip': ['backward']}), | |||
| ('BasicLSTMCell', { | |||
| 'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'), | |||
| 'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]], | |||
| 'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]], | |||
| 'skip': []}), | |||
| ('TopK', { | |||
| 'block': P.TopK(), | |||
| 'desc_const': [5], | |||
| @@ -2346,6 +2351,18 @@ test_case_other_ops = [ | |||
| 'block': P.PopulationCount(), | |||
| 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))], | |||
| 'skip': ['backward']}), | |||
| ('BasicLSTMCellNet', { | |||
| 'block': BasicLSTMCellNet(), | |||
| 'desc_inputs': [Tensor(np.random.rand(1, 32).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(96, 256).astype(np.float16)), | |||
| Tensor(np.random.rand(256, ).astype(np.float16))], | |||
| 'desc_bprop': [Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16)), | |||
| Tensor(np.random.rand(1, 64).astype(np.float16))]}), | |||
| ] | |||
| test_case_quant_ops = [ | |||