diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py index d8bf114578..06c6235d80 100644 --- a/mindspore/ops/operations/inner_ops.py +++ b/mindspore/ops/operations/inner_ops.py @@ -155,19 +155,28 @@ class NoRepeatNGram(PrimitiveWithInfer): Examples: >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3) >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2], - [9, 3, 9, 5, 4, 1, 5]], - [[4, 8, 6, 4, 5, 6, 4], - [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32) - >>> log_probs = Tensor([[[0.75858542, 0.8437121 , 0.69025469, 0.79379992, 0.27400691, - 0.84709179, 0.78771346, 0.68587179, 0.22943851, 0.17682976]], - [[0.99401879, 0.77239773, 0.81973878, 0.32085208, 0.59944118, - 0.3125177, 0.52604189, 0.77111461, 0.98443699, 0.71532898]]], dtype=mindspore.float32) + ... [9, 3, 9, 5, 4, 1, 5]], + ... [[4, 8, 6, 4, 5, 6, 4], + ... [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32) + >>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7], + ... [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]], + ... [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6], + ... [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32) >>> output = no_repeat_ngram(state_seq, log_probs) >>> print(output) - [[[0.75858542 -3.4028235e+38 0.69025469 0.79379992 0.27400691 - -3.4028235e+38 0.78771346 0.68587179 0.22943851 0.17682976]] - [[0.99401879 0.77239773 0.81973878 0.32085208 0.59944118 - -3.4028235e+38 0.52604189 0.77111461 0.98443699 0.71532898]]] + [[[ 6.9999999e-01 -3.4028235e+38 6.0000002e-01 8.9999998e-01 + 2.0000000e-01 -3.4028235e+38 4.0000001e-01 6.0000002e-01 + 2.0000000e-01 6.9999999e-01] + [ 4.0000001e-01 5.0000000e-01 6.0000002e-01 6.9999999e-01 + 8.0000001e-01 1.0000000e-01 8.9999998e-01 8.0000001e-01 + 6.9999999e-01 1.0000000e-01]] + + [[ 8.9999998e-01 6.9999999e-01 6.0000002e-01 3.0000001e-01 + 5.0000000e-01 -3.4028235e+38 5.0000000e-01 4.0000001e-01 + 8.0000001e-01 6.0000002e-01] + [ 5.0000000e-01 8.0000001e-01 8.0000001e-01 6.9999999e-01 + 6.9999999e-01 8.0000001e-01 2.0000000e-01 6.9999999e-01 + -3.4028235e+38 6.9999999e-01]]] """ @prim_attr_register @@ -179,11 +188,11 @@ class NoRepeatNGram(PrimitiveWithInfer): self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs']) def infer_shape(self, seq_shape, log_shape): - validator.check_int(len(seq_shape), 3, Rel.EQ, "rank_of_seq", self.name) - validator.check_int(len(log_shape), 3, Rel.EQ, "rank_of_log", self.name) - validator.check_int(seq_shape[0], log_shape[0], Rel.EQ, "seq_shape shape[0]", self.name) - validator.check_int(seq_shape[1], log_shape[1], Rel.EQ, "seq_shape shape[1]", self.name) - validator.check_int(self.ngram_size, seq_shape[2] + 1, Rel.LE, "ngram_size", self.name) + validator.check_int(len(seq_shape), 3, Rel.EQ, "rank of state_seq", self.name) + validator.check_int(len(log_shape), 3, Rel.EQ, "rank of log_probs", self.name) + validator.check("state_seq shape[0]", seq_shape[0], "log_probs shape[0]", log_shape[0], Rel.EQ, self.name) + validator.check("state_seq shape[1]", seq_shape[1], "log_probs shape[1]", log_shape[1], Rel.EQ, self.name) + validator.check("ngram_size", self.ngram_size, "state_seq shape[2] + 1", seq_shape[2] + 1, Rel.LE, self.name) return log_shape def infer_dtype(self, seq_type, log_type):