@@ -276,12 +276,8 @@ OutHandler OpAdapterImpl::getNormalOutput(const OperatorPtr &op, int index) {
 }
 
 Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
-                                             const TypePtr &type) {
+                                             const TypePtr &type, const std::string &format) {
   MS_EXCEPTION_IF_NULL(type);
-  std::string format = "NCHW";
-  if (op->GetOpType() == kExtractImagePatchesOpName) {
-    format = "NHWC";
-  }
   auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
   if (desc == nullptr) {
@@ -340,7 +336,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::Sh
 }
 
 Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
-                                            const TypePtr &type) {
+                                            const TypePtr &type, const std::string &format) {
   auto tuple_shp = dyn_cast<abstract::TupleShape>(shp);
   MS_EXCEPTION_IF_NULL(tuple_shp);
@@ -361,10 +357,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
     MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
     return FAILED;
   }
-  std::string format = "NCHW";
-  if (op->GetOpType() == kTopKOpName) {
-    format = "NHWC";
-  }
   for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
     auto tuple_type = dyn_cast<Tuple>(type);
     MS_EXCEPTION_IF_NULL(tuple_type);
@@ -389,7 +382,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
   return SUCCESS;
 }
 
-std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node) {
+std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
   MS_EXCEPTION_IF_NULL(node);
   TypeId me_type = node->Type()->type_id();
   if (kObjectTypeTensorType == me_type) {
@@ -405,7 +398,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
     shape = shape_ptr->shape();
   }
 
-  auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
+  auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
   if (desc == nullptr) {
     MS_LOG(ERROR) << "Update output descriptor failed!";
     return nullptr;
@@ -413,7 +406,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
   return desc;
 }
 
-void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
+void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format) {
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is nullptr";
     return;
@@ -424,19 +417,18 @@ void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNode
   for (size_t i = 1; i < inputs.size(); ++i) {
     auto it = input_map_.find(i);
     if (it != input_map_.end()) {
-      auto desc = CreateNodeDesc(inputs[i]);
+      auto desc = CreateNodeDesc(inputs[i], format);
       if (desc == nullptr) {
         continue;
       }
-      if (op->GetOpType() == kExtractImagePatchesOpName) {
-        desc->SetFormat(ge::Format::FORMAT_NHWC);
-      }
       it->second.update_input_desc(op, *desc);
     }
   }
 }
 
-void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
+void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node,
+                                            const std::string format) {
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is nullptr";
     return;
@@ -452,7 +444,7 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
   auto inputs = node->cast<CNodePtr>()->inputs();
   for (size_t i = 1; i < inputs.size(); ++i) {
     if (input_map.find(i) != input_map.end()) {
-      auto desc = CreateNodeDesc(inputs[i]);
+      auto desc = CreateNodeDesc(inputs[i], format);
       if (desc == nullptr) {
         continue;
       }
@@ -464,11 +456,12 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
 void OpAdapterImpl::updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(op);
   MS_EXCEPTION_IF_NULL(node);
+  std::string format = GetOpIOFormat(node);
   if (IsCustomOp(op)) {
     auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
-    UpdateCustomOpInputDesc(cus_op, node);
+    UpdateCustomOpInputDesc(cus_op, node, format);
   } else {
-    UpdateNormalOpInputDesc(op, node);
+    UpdateNormalOpInputDesc(op, node, format);
   }
 }
@@ -483,13 +476,14 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base
   auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
   auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
+  std::string format = GetOpIOFormat(node);
 
   if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) {
-    if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) {
+    if (UpdateSingleOutputDesc(op, shp, type, format) != SUCCESS) {
       return;
     }
   } else if (nullptr != dyn_cast<abstract::TupleShape>(shp)) {
-    if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) {
+    if (UpdateMultiOutputDesc(op, shp, type, format) != SUCCESS) {
       return;
     }
   } else {
@@ -75,14 +75,16 @@ class OpAdapterImpl {
   OutHandler getOutput(const OperatorPtr &op, int index);
   OutHandler getCustomOutput(const OperatorPtr &op, int index);
   OutHandler getNormalOutput(const OperatorPtr &op, int index);
-  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
+  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                                const std::string &format);
   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op);
   std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
                                                  const std::string &format);
-  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
-  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node);
-  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
-  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node);
+  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                               const std::string &format);
+  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format);
+  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format);
+  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format);
   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
   void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
                         const AnfNodePtr &node);
@@ -226,8 +228,9 @@ class OpAdapter : public BaseOpAdapter {
   OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); }
 
-  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
-    return impl_->UpdateSingleOutputDesc(op, shp, type);
+  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                                const std::string &format) {
+    return impl_->UpdateSingleOutputDesc(op, shp, type, format);
   }
 
   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); }
@@ -237,18 +240,21 @@ class OpAdapter : public BaseOpAdapter {
     return impl_->CreateOutputDesc(shape_ptr, type, format);
   }
 
-  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
-    return impl_->UpdateMultiOutputDesc(op, shp, type);
+  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                               const std::string &format) {
+    return impl_->UpdateMultiOutputDesc(op, shp, type, format);
   }
 
-  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node) { return impl_->CreateNodeDesc(node); }
+  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
+    return impl_->CreateNodeDesc(node, format);
+  }
 
-  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) {
-    return impl_->UpdateNormalOpInputDesc(op, node);
+  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) {
+    return impl_->UpdateNormalOpInputDesc(op, node, format);
   }
 
-  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
-    return impl_->UpdateCustomOpInputDesc(op, node);
+  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) {
+    return impl_->UpdateCustomOpInputDesc(op, node, format);
   }
 
   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); }
@@ -247,7 +247,7 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
     return false;
   }
   if (node->inputs().empty()) {
-    MS_LOG(EXCEPTION) << "length of node inputs is empty";
+    MS_LOG(EXCEPTION) << "Length of node inputs is empty";
   }
   MS_EXCEPTION_IF_NULL(node->inputs()[0]);
   if (!node->inputs()[0]->isa<ValueNode>()) {
@@ -260,5 +260,37 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
   return IsCustomPrim(cus_prim);
 }
+
+std::string GetOpIOFormat(const AnfNodePtr &anf) {
+  std::string ret;
+  if (anf == nullptr) {
+    MS_LOG(ERROR) << "The anf is nullptr";
+    return ret;
+  }
+  auto node = anf->cast<CNodePtr>();
+  if (node == nullptr) {
+    MS_LOG(ERROR) << "The anf is not a cnode.";
+    return ret;
+  }
+  if (node->inputs().empty()) {
+    MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
+  }
+  MS_EXCEPTION_IF_NULL(node->inputs()[0]);
+  if (!node->inputs()[0]->isa<ValueNode>()) {
+    MS_LOG(ERROR) << "The anf is not a value node.";
+    return ret;
+  }
+  auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "The anf is not a Primitive.";
+    return ret;
+  }
+  ValuePtr format = prim->GetAttr("io_format");
+  if (format == nullptr) {
+    return "NCHW";
+  }
+  ret = GetValue<std::string>(format);
+  return ret;
+}
 }  // namespace transform
 }  // namespace mindspore
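
The new `GetOpIOFormat` helper reads an `io_format` attribute from the node's primitive and falls back to "NCHW" when the attribute is absent, replacing the per-op format checks removed from `UpdateSingleOutputDesc`, `UpdateMultiOutputDesc`, and `UpdateNormalOpInputDesc` above. A minimal Python sketch of the same lookup logic; the helper `get_io_format` is illustrative only, not a MindSpore API:

```python
# Illustrative sketch of the io_format lookup, mirroring GetOpIOFormat above.
def get_io_format(prim_attrs: dict) -> str:
    """Return the primitive's declared IO format, defaulting to NCHW."""
    return prim_attrs.get("io_format", "NCHW")

# An op such as ExtractImagePatches declares {"io_format": "NHWC"}; ops that
# declare nothing keep the NCHW default, matching the previous behaviour.
assert get_io_format({"io_format": "NHWC"}) == "NHWC"
assert get_io_format({}) == "NCHW"
```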
@@ -61,6 +61,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>);
 bool IsCustomPrim(const PrimitivePtr &prim);
 bool IsCustomCNode(const AnfNodePtr &node);
+std::string GetOpIOFormat(const AnfNodePtr &node);
 }  // namespace transform
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
@@ -25,7 +25,7 @@ ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>()
                            {"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())},
                            {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
 OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)},
-                             {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}};
+                             {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {6, OUTPUT_DESC(tanhct)}};
 REG_ADPT_DESC(BasicLSTMCell, kNameBasicLSTMCell, ADPT_DESC(BasicLSTMCell))
@@ -35,7 +35,7 @@ OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht
 REG_ADPT_DESC(BasicLSTMCellInputGrad, kNameBasicLSTMCellInputGrad, ADPT_DESC(BasicLSTMCellInputGrad))
 
 // BasicLSTMCellWeightGrad
-INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}};
+INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(dgate)}};
 ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}};
 REG_ADPT_DESC(BasicLSTMCellWeightGrad, kNameBasicLSTMCellWeightGrad, ADPT_DESC(BasicLSTMCellWeightGrad))
@@ -87,7 +87,10 @@ GeFormat TransformUtil::ConvertFormat(const string &format) {
     return GeFormat::FORMAT_NHWC;
   } else if (format == kOpFormat_HWCN) {
     return GeFormat::FORMAT_HWCN;
+  } else if (format == kOpFormat_ND) {
+    return GeFormat::FORMAT_ND;
   } else {
+    MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
     return GeFormat::FORMAT_ND;
   }
 }
@@ -113,8 +116,7 @@ std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const std::vector<i
   // convert me format to ge format
   GeFormat ge_format = ConvertFormat(format);
   if (ge_format == GeFormat::FORMAT_ND) {
-    MS_LOG(ERROR) << "undefined data format : " << static_cast<int>(ge_format);
-    return nullptr;
+    MS_LOG(INFO) << "Set ND data format";
   }
   // convert me datatype to ge datatype
   GeDataType data_type = ConvertDataType(me_type);
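
With these two hunks, `ConvertFormat` recognizes the ND layout explicitly and downgrades unknown layouts to ND with an error log, and `GetGeTensorDesc` no longer rejects ND descriptors. A rough Python sketch of the new fallback behaviour; the names are illustrative, not the actual C++ API:

```python
# Rough sketch of the reworked ConvertFormat fallback (illustrative only).
KNOWN_FORMATS = {"NCHW", "NHWC", "HWCN", "ND"}

def convert_format(fmt: str) -> str:
    if fmt in KNOWN_FORMATS:
        return fmt
    # Unknown layouts now degrade to ND instead of aborting descriptor creation.
    print(f"Illegal tensor data format: ({fmt}). Use ND format instead.")
    return "ND"

assert convert_format("ND") == "ND"
assert convert_format("xyz") == "ND"  # previously this made GetGeTensorDesc return nullptr
```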
@@ -1537,6 +1537,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
     def __init__(self, forget_bias, activation):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
         # dhy and dcy should be same shape
@@ -1586,7 +1587,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
-        pass
+        self.add_prim_attr("io_format", "HWCN")
 
     def infer_shape(self, x_shape, h_shape, dgate_shape):
         validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
@@ -1595,8 +1596,10 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
         validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
         validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
         validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
-        dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
-        db_shape = (dgate_shape[1], 1, 1, 1)
+        input_size = x_shape[1]
+        hidden_size = h_shape[1]
+        dw_shape = (input_size + hidden_size, 4 * hidden_size)
+        db_shape = (4 * hidden_size,)
         return (dw_shape, db_shape)
 
     def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
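
The weight and bias gradients switch from the old 4-D fractal-style shapes to plain 2-D and 1-D shapes. A worked example with the sizes used in the updated BasicLSTMCell docstring (input_size=32, hidden_size=64), as a sketch only:

```python
# Worked example of the new BasicLSTMCellWeightGrad output shapes (sketch only).
batch_size, input_size, hidden_size = 1, 32, 64
x_shape = (batch_size, input_size)            # (1, 32)
h_shape = (batch_size, hidden_size)           # (1, 64)
dgate_shape = (batch_size, 4 * hidden_size)   # (1, 256)

dw_shape = (input_size + hidden_size, 4 * hidden_size)  # (96, 256)
db_shape = (4 * hidden_size,)                            # (256,)
assert dw_shape == (96, 256) and db_shape == (256,)
```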
@@ -1616,13 +1619,17 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
     def __init__(self, keep_prob):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
         self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, dgate_shape, w_shape):
         validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
-        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
-        dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
-        dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        batch_size = dgate_shape[0]
+        hidden_size = dgate_shape[1] // 4
+        input_size = w_shape[0] - hidden_size
+        dxt_shape = (batch_size, input_size)
+        dht_shape = (batch_size, hidden_size)
         return (dxt_shape, dht_shape)
 
     def infer_dtype(self, dgate_dtype, w_dtype):
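
With the 2-D weight layout, the input-gradient shapes are recovered from `dgate` and the weight's row count. The same example sizes as above, again as a sketch:

```python
# Worked example of the reworked BasicLSTMCellInputGrad shape inference (sketch only).
batch_size, input_size, hidden_size = 1, 32, 64
dgate_shape = (batch_size, 4 * hidden_size)            # (1, 256)
w_shape = (input_size + hidden_size, 4 * hidden_size)  # (96, 256)

hidden = dgate_shape[1] // 4           # 64
inp = w_shape[0] - hidden              # 32
dxt_shape = (dgate_shape[0], inp)      # (1, 32)
dht_shape = (dgate_shape[0], hidden)   # (1, 64)
assert dxt_shape == (1, 32) and dht_shape == (1, 64)
```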
@@ -198,6 +198,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
         _check_tuple_or_list("rate", rates, self.name)
         self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
         self.add_prim_attr("padding", self.padding)
+        self.add_prim_attr("io_format", "NHWC")
 
     def infer_shape(self, input_x):
         """infer shape"""
@@ -5335,35 +5335,41 @@ class BasicLSTMCell(PrimitiveWithInfer):
         forget_bias (float): Add forget bias to forget gate biases in order to decrease former scale. Default to 1.0.
         state_is_tuple (bool): If true, state is tensor tuple, containing h and c; If false, one tensor,
           need split first. Default to True.
-        activation (str): Activation. Default to "tanh".
+        activation (str): Activation. Default to "tanh". Only "tanh" is currently supported.
 
     Inputs:
         - **x** (Tensor) - Current words. Tensor of shape (`batch_size`, `input_size`).
+          The data type must be float16 or float32.
         - **h** (Tensor) - Hidden state last moment. Tensor of shape (`batch_size`, `hidden_size`).
+          The data type must be float16 or float32.
         - **c** (Tensor) - Cell state last moment. Tensor of shape (`batch_size`, `hidden_size`).
-        - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
-        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
+          The data type must be float16 or float32.
+        - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
+          The data type must be float16 or float32.
+        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
+          The data type must be same as `c`.
 
     Outputs:
         - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
-        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
+        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`). With data type of float16.
         - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **tanhct** (Tensor) - Forward :math:`tanh c_t` cache at moment `t`.
-          Tensor of shape (`batch_size`, `hidden_size`).
+          Tensor of shape (`batch_size`, `hidden_size`). Has the same type with input `c`.
 
     Examples:
-        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
-        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]],
-        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
-        >>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
-        >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
+        >>> x = Tensor(np.random.rand(1, 32).astype(np.float16))
+        >>> h = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> c = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> w = Tensor(np.random.rand(96, 256).astype(np.float16))
+        >>> b = Tensor(np.random.rand(256, ).astype(np.float16))
        >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
        >>> lstm(x, h, c, w, b)
     """
@@ -5375,42 +5381,38 @@ class BasicLSTMCell(PrimitiveWithInfer):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
         self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
-        # (batch_size, input_size)
-        validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
-        # h and c should be same shape
-        validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
-        validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
-        validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
-        validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
-        validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
-        validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
+        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
+        validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[1]", w_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[0]", w_shape[0], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
         validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
         ct_shape = c_shape
-        ht_shape = h_shape
-        it_shape = h_shape
-        jt_shape = h_shape
-        ft_shape = h_shape
-        ot_shape = h_shape
-        tanhct_shape = h_shape
+        ht_shape = c_shape
+        it_shape = c_shape
+        jt_shape = c_shape
+        ft_shape = c_shape
+        ot_shape = c_shape
+        tanhct_shape = c_shape
         return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
 
     def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
-        validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
-        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
-        return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
+        validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name)
+        args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
 
 
 class InTopK(PrimitiveWithInfer):
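
The updated `infer_shape` constraints match the new 2-D weight and 1-D bias layout from the docstring; all seven outputs now follow `c`'s shape, and `infer_dtype` returns float16 for `ht`. A quick check with the docstring example shapes, as a sketch:

```python
# Quick check of the new BasicLSTMCell shape constraints (sketch only).
x_shape, h_shape, c_shape = (1, 32), (1, 64), (1, 64)
w_shape, b_shape = (96, 256), (256,)

assert w_shape[0] == x_shape[1] + h_shape[1]  # 96 = 32 + 64
assert w_shape[1] == 4 * h_shape[1]           # 256 = 4 * 64
assert b_shape[0] == 4 * h_shape[1]           # 256
# ct, ht, it, jt, ft, ot and tanhct all take c's shape.
output_shapes = tuple(c_shape for _ in range(7))
assert output_shapes == ((1, 64),) * 7
```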
@@ -735,7 +735,7 @@ TEST_F(TestConvert, TestConvertTensorError) {
   std::vector<int> dims2{2, 3, 4};
   auto type_id_2 = kNumberTypeFloat32;
   auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
-  ASSERT_EQ(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
+  ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
 }
 
 TEST_F(TestConvert, TestUtilsConvertDataType) {
@@ -701,6 +701,16 @@ class ParallelConcatNet(nn.Cell):
         return self.parallel_concat((x1, x2))
 
 
+class BasicLSTMCellNet(nn.Cell):
+    """ BasicLSTMCellNet definition """
+
+    def __init__(self):
+        super(BasicLSTMCellNet, self).__init__()
+        self.lstm = P.BasicLSTMCell()
+
+    def construct(self, x, h, c, w, b):
+        return self.lstm(x, h, c, w, b)
+
+
 class EditDistance(nn.Cell):
     def __init__(self, hypothesis_shape, truth_shape, normalize=True):
         super(EditDistance, self).__init__()
@@ -1402,11 +1412,6 @@ test_case_nn_ops = [
         'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
-    ('BasicLSTMCell', {
-        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
-        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]],
-        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
-        'skip': []}),
     ('TopK', {
         'block': P.TopK(),
         'desc_const': [5],
@@ -2346,6 +2351,18 @@ test_case_other_ops = [
         'block': P.PopulationCount(),
         'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))],
         'skip': ['backward']}),
+    ('BasicLSTMCellNet', {
+        'block': BasicLSTMCellNet(),
+        'desc_inputs': [Tensor(np.random.rand(1, 32).astype(np.float16)),
+                        Tensor(np.random.rand(1, 64).astype(np.float16)),
+                        Tensor(np.random.rand(1, 64).astype(np.float16)),
+                        Tensor(np.random.rand(96, 256).astype(np.float16)),
+                        Tensor(np.random.rand(256, ).astype(np.float16))],
+        'desc_bprop': [Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16))]}),
 ]
 
 test_case_quant_ops = [