| @@ -32,6 +32,7 @@ namespace mindspore { | |||||
| // op name. Op which not exists in operator/ops.h, so define it's name here | // op name. Op which not exists in operator/ops.h, so define it's name here | ||||
| constexpr auto kUniqueOpName = "Unique"; | constexpr auto kUniqueOpName = "Unique"; | ||||
| constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits"; | constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits"; | ||||
| constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder"; | |||||
| constexpr auto kFour2FiveOpName = "Four2Five"; | constexpr auto kFour2FiveOpName = "Four2Five"; | ||||
| constexpr auto kFive2FourOpName = "Five2Four"; | constexpr auto kFive2FourOpName = "Five2Four"; | ||||
| constexpr auto kConv2DOpName = "Conv2D"; | constexpr auto kConv2DOpName = "Conv2D"; | ||||
| @@ -486,7 +487,7 @@ const std::set<std::string> kHWSpecialFormatSet = { | |||||
| const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; | ||||
| const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, | const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, | ||||
| kPadAndShiftOpName}; | |||||
| kPadAndShiftOpName, kCTCGreedyDecoderOpName}; | |||||
| const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; | const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; | ||||
| @@ -205,6 +205,8 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -486,6 +486,42 @@ AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &prim | |||||
| return std::make_shared<AbstractTuple>(elements); | return std::make_shared<AbstractTuple>(elements); | ||||
| } | } | ||||
| AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| // inputs: inputs, sequence_length | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 2); | |||||
| AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| auto shape = input->shape(); | |||||
| if (shape->shape().size() != 3) { | |||||
| MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 3."; | |||||
| } | |||||
| ShapeVector indices_shape = {Shape::SHP_ANY, 2}; | |||||
| ShapeVector min_shape = {1, 2}; | |||||
| ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1], 2}; | |||||
| auto decoded_indices = | |||||
| std::make_shared<AbstractTensor>(kInt64, std::make_shared<Shape>(indices_shape, min_shape, max_shape)); | |||||
| ShapeVector values_shape = {Shape::SHP_ANY}; | |||||
| ShapeVector values_min_shape = {1}; | |||||
| ShapeVector values_max_shape = {shape->shape()[0] * shape->shape()[1]}; | |||||
| ShapePtr values_shapes = std::make_shared<Shape>(values_shape, values_min_shape, values_max_shape); | |||||
| auto decoded_values = std::make_shared<AbstractTensor>(kInt64, values_shapes); | |||||
| ShapeVector decoded_shape_shape = {2}; | |||||
| auto decoded_shape = std::make_shared<AbstractTensor>(kInt64, decoded_shape_shape); | |||||
| ShapeVector log_probability_shape = {shape->shape()[1], 1}; | |||||
| auto log_probability = | |||||
| std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(log_probability_shape)); | |||||
| // outputs: decoded_indices, decoded_values, decoded_shape, log_probability | |||||
| AbstractBasePtrList elements = {decoded_indices, decoded_values, decoded_shape, log_probability}; | |||||
| return std::make_shared<AbstractTuple>(elements); | |||||
| } | |||||
| AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| @@ -120,6 +120,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, | {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}}, | ||||
| {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, | {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}}, | ||||
| {prim::kPrimSGD, {InferImplSGD, true}}, | {prim::kPrimSGD, {InferImplSGD, true}}, | ||||
| {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}}, | |||||
| // Others | // Others | ||||
| {prim::kPrimIdentity, {InferImplIdentity, true}}, | {prim::kPrimIdentity, {InferImplIdentity, true}}, | ||||
| // Set impl to null as it will use PartialEvaluator; | // Set impl to null as it will use PartialEvaluator; | ||||
| @@ -160,6 +160,7 @@ inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive | |||||
| inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput"); | inline const PrimitivePtr kPrimConv3DBackpropInput = std::make_shared<Primitive>("Conv3DBackpropInput"); | ||||
| inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter"); | inline const PrimitivePtr kPrimConv3DBackpropFilter = std::make_shared<Primitive>("Conv3DBackpropFilter"); | ||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative"); | ||||
| inline const PrimitivePtr kPrimCTCGreedyDecoder = std::make_shared<Primitive>("CTCGreedyDecoder"); | |||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = | ||||
| std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter"); | ||||
| inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = | ||||
| @@ -6195,7 +6195,7 @@ class CTCLoss(PrimitiveWithInfer): | |||||
| return inputs, inputs | return inputs, inputs | ||||
| class CTCGreedyDecoder(PrimitiveWithInfer): | |||||
| class CTCGreedyDecoder(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| Performs greedy decoding on the logits given in inputs. | Performs greedy decoding on the logits given in inputs. | ||||
| @@ -6221,29 +6221,22 @@ class CTCGreedyDecoder(PrimitiveWithInfer): | |||||
| containing sequence log-probability, has the same type as `inputs`. | containing sequence log-probability, has the same type as `inputs`. | ||||
| Examples: | Examples: | ||||
| >>> class CTCGreedyDecoderNet(nn.Cell): | |||||
| ... def __init__(self): | |||||
| ... super(CTCGreedyDecoderNet, self).__init__() | |||||
| ... self.ctc_greedy_decoder = P.CTCGreedyDecoder() | |||||
| ... self.assert_op = ops.Assert(300) | |||||
| ... | |||||
| ... def construct(self, inputs, sequence_length): | |||||
| ... out = self.ctc_greedy_decoder(inputs,sequence_length) | |||||
| ... self.assert_op(True, (out[0], out[1], out[2], out[3])) | |||||
| ... return out[2] | |||||
| ... | |||||
| >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32) | >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32) | ||||
| >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) | >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) | ||||
| >>> net = CTCGreedyDecoderNet() | |||||
| >>> output = net(inputs, sequence_length) | |||||
| >>> print(output) | |||||
| >>> ctc_greedy_decoder = ops.CTCGreedyDecoder() | |||||
| >>> out1, out2, out3, out4 = ctc_greedy_decoder(inputs, sequence_length) | |||||
| >>> print(out1, out2, out3, out4) | |||||
| [[0 0] [0 1] [1 0]] | |||||
| [0 1 0] | |||||
| [2 2] | |||||
| [[-0.7443749] [0.18251707]] | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, merge_repeated=True): | def __init__(self, merge_repeated=True): | ||||
| self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name) | self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name) | ||||
| def infer_shape(self, inputs_shape, sequence_length_shape): | |||||
| def check_shape(self, inputs_shape, sequence_length_shape): | |||||
| validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name) | validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name) | ||||
| validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name) | validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name) | ||||
| validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size', | validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size', | ||||
| @@ -6255,7 +6248,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer): | |||||
| log_probability_shape = [inputs_shape[1], 1] | log_probability_shape = [inputs_shape[1], 1] | ||||
| return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape | return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape | ||||
| def infer_dtype(self, inputs_dtype, sequence_length_dtype): | |||||
| def check_dtype(self, inputs_dtype, sequence_length_dtype): | |||||
| validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name) | validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name) | ||||
| validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name) | validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name) | ||||
| decoded_type = mstype.tensor_type(mstype.int64) | decoded_type = mstype.tensor_type(mstype.int64) | ||||