|
|
|
@@ -3410,7 +3410,7 @@ class MirrorPad(PrimitiveWithInfer): |
|
|
|
'value': None} |
|
|
|
|
|
|
|
|
|
|
|
class ComputeAccidentalHits(PrimitiveWithInfer): |
|
|
|
class ComputeAccidentalHits(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Compute accidental hits of sampled classes which happen to match target classes. |
|
|
|
|
|
|
|
@@ -3455,17 +3455,18 @@ class ComputeAccidentalHits(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'], |
|
|
|
outputs=['indices', 'ids', 'weights']) |
|
|
|
validator.check_value_type("num_true", num_true, [int], self.name) |
|
|
|
validator.check_number("num_true", num_true, 1, Rel.GE, self.name) |
|
|
|
self.num_true = num_true |
|
|
|
|
|
|
|
def infer_shape(self, true_classes_shape, sampled_candidates_shape): |
|
|
|
validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name) |
|
|
|
validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name) |
|
|
|
validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name) |
|
|
|
def check_shape(self, true_classes_shape, sampled_candidates_shape): |
|
|
|
validator.check_int(len(true_classes_shape), 2, Rel.EQ, 'dim of true_classes', self.name) |
|
|
|
validator.check_int(len(sampled_candidates_shape), 1, Rel.EQ, 'dim of sampled_candidates', self.name) |
|
|
|
validator.check("true_classes shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name) |
|
|
|
|
|
|
|
indices_len = -1 |
|
|
|
return (indices_len,), (indices_len,), (indices_len,) |
|
|
|
|
|
|
|
def infer_dtype(self, true_classes_type, sampled_candidates_type): |
|
|
|
def check_dtype(self, true_classes_type, sampled_candidates_type): |
|
|
|
validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name) |
|
|
|
validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name) |
|
|
|
valid_types = (mstype.int32, mstype.int64) |
|
|
|
@@ -6107,13 +6108,13 @@ class CTCLoss(PrimitiveWithInfer): |
|
|
|
>>> ctc_loss = ops.CTCLoss() |
|
|
|
>>> loss, gradient = ctc_loss(inputs, labels_indices, labels_values, sequence_length) |
|
|
|
>>> print(loss) |
|
|
|
[0.69121575 0.5381993 ] |
|
|
|
[0.69121575 0.5381993] |
|
|
|
>>> print(gradient) |
|
|
|
[[[ 0.25831494 0.3623634 -0.62067937] |
|
|
|
[ 0.25187883 0.2921483 -0.5440271 ]] |
|
|
|
[[[0.25831494 0.3623634 -0.62067937] |
|
|
|
[0.25187883 0.2921483 -0.5440271]] |
|
|
|
|
|
|
|
[[ 0.43522435 0.24408469 0.07787037 ] |
|
|
|
[ 0.29642645 0.4232373 0.06138104 ]]] |
|
|
|
[[0.43522435 0.24408469 0.07787037] |
|
|
|
[0.29642645 0.4232373 0.06138104]]] |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
|