| @@ -155,19 +155,28 @@ class NoRepeatNGram(PrimitiveWithInfer): | |||||
| Examples: | Examples: | ||||
| >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3) | >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3) | ||||
| >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2], | >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2], | ||||
| [9, 3, 9, 5, 4, 1, 5]], | |||||
| [[4, 8, 6, 4, 5, 6, 4], | |||||
| [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32) | |||||
| >>> log_probs = Tensor([[[0.75858542, 0.8437121 , 0.69025469, 0.79379992, 0.27400691, | |||||
| 0.84709179, 0.78771346, 0.68587179, 0.22943851, 0.17682976]], | |||||
| [[0.99401879, 0.77239773, 0.81973878, 0.32085208, 0.59944118, | |||||
| 0.3125177, 0.52604189, 0.77111461, 0.98443699, 0.71532898]]], dtype=mindspore.float32) | |||||
| ... [9, 3, 9, 5, 4, 1, 5]], | |||||
| ... [[4, 8, 6, 4, 5, 6, 4], | |||||
| ... [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32) | |||||
| >>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7], | |||||
| ... [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]], | |||||
| ... [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6], | |||||
| ... [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32) | |||||
| >>> output = no_repeat_ngram(state_seq, log_probs) | >>> output = no_repeat_ngram(state_seq, log_probs) | ||||
| >>> print(output) | >>> print(output) | ||||
| [[[0.75858542 -3.4028235e+38 0.69025469 0.79379992 0.27400691 | |||||
| -3.4028235e+38 0.78771346 0.68587179 0.22943851 0.17682976]] | |||||
| [[0.99401879 0.77239773 0.81973878 0.32085208 0.59944118 | |||||
| -3.4028235e+38 0.52604189 0.77111461 0.98443699 0.71532898]]] | |||||
| [[[ 6.9999999e-01 -3.4028235e+38 6.0000002e-01 8.9999998e-01 | |||||
| 2.0000000e-01 -3.4028235e+38 4.0000001e-01 6.0000002e-01 | |||||
| 2.0000000e-01 6.9999999e-01] | |||||
| [ 4.0000001e-01 5.0000000e-01 6.0000002e-01 6.9999999e-01 | |||||
| 8.0000001e-01 1.0000000e-01 8.9999998e-01 8.0000001e-01 | |||||
| 6.9999999e-01 1.0000000e-01]] | |||||
| [[ 8.9999998e-01 6.9999999e-01 6.0000002e-01 3.0000001e-01 | |||||
| 5.0000000e-01 -3.4028235e+38 5.0000000e-01 4.0000001e-01 | |||||
| 8.0000001e-01 6.0000002e-01] | |||||
| [ 5.0000000e-01 8.0000001e-01 8.0000001e-01 6.9999999e-01 | |||||
| 6.9999999e-01 8.0000001e-01 2.0000000e-01 6.9999999e-01 | |||||
| -3.4028235e+38 6.9999999e-01]]] | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| @@ -179,11 +188,11 @@ class NoRepeatNGram(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs']) | self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs']) | ||||
| def infer_shape(self, seq_shape, log_shape): | def infer_shape(self, seq_shape, log_shape): | ||||
| validator.check_int(len(seq_shape), 3, Rel.EQ, "rank_of_seq", self.name) | |||||
| validator.check_int(len(log_shape), 3, Rel.EQ, "rank_of_log", self.name) | |||||
| validator.check_int(seq_shape[0], log_shape[0], Rel.EQ, "seq_shape shape[0]", self.name) | |||||
| validator.check_int(seq_shape[1], log_shape[1], Rel.EQ, "seq_shape shape[1]", self.name) | |||||
| validator.check_int(self.ngram_size, seq_shape[2] + 1, Rel.LE, "ngram_size", self.name) | |||||
| validator.check_int(len(seq_shape), 3, Rel.EQ, "rank of state_seq", self.name) | |||||
| validator.check_int(len(log_shape), 3, Rel.EQ, "rank of log_probs", self.name) | |||||
| validator.check("state_seq shape[0]", seq_shape[0], "log_probs shape[0]", log_shape[0], Rel.EQ, self.name) | |||||
| validator.check("state_seq shape[1]", seq_shape[1], "log_probs shape[1]", log_shape[1], Rel.EQ, self.name) | |||||
| validator.check("ngram_size", self.ngram_size, "state_seq shape[2] + 1", seq_shape[2] + 1, Rel.LE, self.name) | |||||
| return log_shape | return log_shape | ||||
| def infer_dtype(self, seq_type, log_type): | def infer_dtype(self, seq_type, log_type): | ||||