Browse Source

!10107 fix ComputeAccidentalHits example

From: @yanzhenxiang2020
Reviewed-by: @wuxuejian,@oacjiewen
Signed-off-by: @wuxuejian
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
5e94f71ba1
2 changed files with 4 additions and 3 deletions
  1. +1
    -1
      mindspore/core/abstract/prim_nn.cc
  2. +3
    -2
      mindspore/ops/operations/nn_ops.py

+ 1
- 1
mindspore/core/abstract/prim_nn.cc View File

@@ -529,7 +529,7 @@ AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const

auto shape = input->shape();
if (shape->shape().size() != 2) {
MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 2.";
}
ShapeVector indices_shape = {Shape::SHP_ANY};
ShapeVector min_shape = {1};


+ 3
- 2
mindspore/ops/operations/nn_ops.py View File

@@ -3425,8 +3425,9 @@ class ComputeAccidentalHits(PrimitiveWithCheck):
>>> sampler = ops.ComputeAccidentalHits(2)
>>> output1, output2, output3 = sampler(Tensor(x), Tensor(y))
>>> print(output1, output2, output3)
[0, 0, 1, 1, 2, 2], [1, 2, 0, 4, 3, 3],
[-3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38, -3.4028235+38]
[0 0 1 1 2 2]
[1 2 0 4 3 3]
[-3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38]

"""



Loading…
Cancel
Save