fix shape of ComputeAccidentalHits

tags/v1.1.0
yanzhenxiang2020 5 years ago
parent commit e9eb1ebac8
7 changed files with 45 additions and 14 deletions
  1. mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc (+1, -1)
  2. mindspore/core/abstract/infer_functions.h (+2, -0)
  3. mindspore/core/abstract/prim_nn.cc (+26, -0)
  4. mindspore/core/abstract/primitive_infer_map.cc (+1, -0)
  5. mindspore/core/base/core_ops.h (+1, -0)
  6. mindspore/ops/operations/nn_ops.py (+12, -11)
  7. tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py (+2, -2)

mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc (+1, -1)

@@ -28,7 +28,7 @@
 namespace mindspore {
 namespace device {
 namespace ascend {
-std::set<std::string> kComputeDepend = {"Unique"};
+std::set<std::string> kComputeDepend = {"Unique", "ComputeAccidentalHits"};
 AiCpuDynamicKernel::~AiCpuDynamicKernel() {
   // free dev ptr
   if (ext_info_addr_dev_ == nullptr) {
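Background for this hunk (not part of the commit): kComputeDepend appears to list AICPU ops whose output shape depends on the input values, so the dynamic kernel must fetch the real output shape after execution instead of deriving it up front. A minimal NumPy sketch of that data dependence — illustrative only; the inputs are chosen to reproduce the expected outputs of the updated test at the bottom of this commit:

    import numpy as np

    true_classes = np.array([[1, 2], [0, 4], [3, 3]])  # 3 rows, num_true = 2
    sampled_candidates = np.array([0, 1, 2, 3, 4])

    # An "accidental hit" pairs a row of true_classes with a true label that
    # also occurs among the sampled ids; this per-label reading reproduces the
    # test's expected indices/ids below.
    hits = [(row, int(t)) for row in range(true_classes.shape[0])
            for t in true_classes[row] if t in sampled_candidates]
    print([r for r, _ in hits])  # [0, 0, 1, 1, 2, 2]
    print([t for _, t in hits])  # [1, 2, 0, 4, 3, 3]
    print(len(hits))             # 6 for these values; unknown at compile time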


mindspore/core/abstract/infer_functions.h (+2, -0)

@@ -221,6 +221,8 @@
 AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


mindspore/core/abstract/prim_nn.cc (+26, -0)

@@ -519,5 +519,31 @@
   }
   return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
 }
+
+AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list) {
+  // inputs: true_classes, sampled_candidates
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+
+  auto shape = input->shape();
+  if (shape->shape().size() != 2) {
+    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 2.";
+  }
+  ShapeVector indices_shape = {Shape::SHP_ANY};
+  ShapeVector min_shape = {1};
+  ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1]};
+
+  auto indices =
+    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(indices_shape, min_shape, max_shape));
+
+  auto weights = std::make_shared<AbstractTensor>(kFloat32, indices_shape);
+  weights->set_shape(std::make_shared<Shape>(indices_shape, min_shape, max_shape));
+  // outputs: indices, ids, weights
+  AbstractBasePtrList elements = {indices, indices, weights};
+  return std::make_shared<AbstractTuple>(elements);
+}
+
 }  // namespace abstract
 }  // namespace mindspore
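How the new inference behaves from the Python side — a hedged sketch: the input values are inferred from the expected outputs in the updated test below, and the Ascend graph-mode setup is an assumption. Each of the three outputs is 1-D with a run-time length, which the frontend can now only bound between the declared min shape {1} and max shape {m * n} for an (m, n) true_classes input:

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import Tensor

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    class Net(nn.Cell):
        def __init__(self, num_true):
            super(Net, self).__init__()
            self.op = P.ComputeAccidentalHits(num_true)

        def construct(self, true_classes, sampled_candidates):
            return self.op(true_classes, sampled_candidates)

    # Values chosen to match the expected outputs in the updated test below.
    true_classes = Tensor(np.array([[1, 2], [0, 4], [3, 3]], dtype=np.int32))
    sampled_candidates = Tensor(np.array([0, 1, 2, 3, 4], dtype=np.int32))

    indices, ids, weights = Net(num_true=2)(true_classes, sampled_candidates)
    # indices -> [0 0 1 1 2 2], ids -> [1 2 0 4 3 3]; each output has run-time
    # length k, statically bounded here as 1 <= k <= 3 * 2 = 6.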

mindspore/core/abstract/primitive_infer_map.cc (+1, -0)

@@ -69,6 +69,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
     {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
     {prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
+    {prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}},
     {prim::kPrimDiv, {InferImplDiv, true}},
     {prim::kPrimRealDiv, {InferImplRealDiv, true}},
     {prim::kPrimShape, {InferImplShape, false}},


mindspore/core/base/core_ops.h (+1, -0)

@@ -101,6 +101,7 @@ inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
 inline const PrimitivePtr kPrimSubAndFilter = std::make_shared<Primitive>("SubAndFilter");
 inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx");
 inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache");
+inline const PrimitivePtr kPrimComputeAccidentalHits = std::make_shared<Primitive>("ComputeAccidentalHits");
 inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable");
 inline const PrimitivePtr kPrimSlice = std::make_shared<Primitive>("Slice");
 inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");


mindspore/ops/operations/nn_ops.py (+12, -11)

@@ -3410,7 +3410,7 @@ class MirrorPad(PrimitiveWithInfer):
                 'value': None}
 
 
-class ComputeAccidentalHits(PrimitiveWithInfer):
+class ComputeAccidentalHits(PrimitiveWithCheck):
     """
     Compute accidental hits of sampled classes which happen to match target classes.
 
@@ -3455,17 +3455,18 @@ class ComputeAccidentalHits(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'],
                                 outputs=['indices', 'ids', 'weights'])
         validator.check_value_type("num_true", num_true, [int], self.name)
+        validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
         self.num_true = num_true
 
-    def infer_shape(self, true_classes_shape, sampled_candidates_shape):
-        validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name)
-        validator.check("sampled_candidates shape rank", len(sampled_candidates_shape), "expect", 1, Rel.EQ, self.name)
-        validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name)
+    def check_shape(self, true_classes_shape, sampled_candidates_shape):
+        validator.check_int(len(true_classes_shape), 2, Rel.EQ, 'dim of true_classes', self.name)
+        validator.check_int(len(sampled_candidates_shape), 1, Rel.EQ, 'dim of sampled_candidates', self.name)
+        validator.check("true_classes shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
         indices_len = -1
         return (indices_len,), (indices_len,), (indices_len,)
 
-    def infer_dtype(self, true_classes_type, sampled_candidates_type):
+    def check_dtype(self, true_classes_type, sampled_candidates_type):
         validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
         validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name)
         valid_types = (mstype.int32, mstype.int64)
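Why the base class changed in the hunk above: PrimitiveWithInfer.infer_shape must return concrete output shapes at compile time, which cannot express a length that is only known once the kernel has run. PrimitiveWithCheck's check_shape/check_dtype only validate the inputs, deferring output shapes to the C++ InferImplComputeAccidentalHits registered earlier in this commit, which can emit Shape::SHP_ANY with min/max bounds. A minimal sketch of that contract (the class below is hypothetical, assuming MindSpore 1.1's exports):

    from mindspore.ops import PrimitiveWithCheck, prim_attr_register

    class DataDependentOp(PrimitiveWithCheck):  # hypothetical example op
        """Pattern for ops whose output length is only known at run time."""

        @prim_attr_register
        def __init__(self):
            self.init_prim_io_names(inputs=['x'], outputs=['y'])

        def check_shape(self, x_shape):
            # Validate only; the output shape comes from the registered
            # C++ InferImpl, not from this method.
            if len(x_shape) != 1:
                raise ValueError("x must be 1-D")

        def check_dtype(self, x_dtype):
            pass  # dtype checks would mirror the validator calls above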
@@ -6107,13 +6108,13 @@ class CTCLoss(PrimitiveWithInfer):
         >>> ctc_loss = ops.CTCLoss()
         >>> loss, gradient = ctc_loss(inputs, labels_indices, labels_values, sequence_length)
         >>> print(loss)
-        [0.69121575 0.5381993 ]
+        [0.69121575 0.5381993]
         >>> print(gradient)
-        [[[ 0.25831494 0.3623634 -0.62067937]
-        [ 0.25187883 0.2921483 -0.5440271 ]]
+        [[[0.25831494 0.3623634 -0.62067937]
+        [0.25187883 0.2921483 -0.5440271]]
 
-        [[ 0.43522435 0.24408469 0.07787037 ]
-        [ 0.29642645 0.4232373 0.06138104 ]]]
+        [[0.43522435 0.24408469 0.07787037]
+        [0.29642645 0.4232373 0.06138104]]]
     """
 
 @prim_attr_register


tests/st/ops/ascend/test_aicpu_ops/test_compute_accidental_hits.py (+2, -2)

@@ -40,8 +40,8 @@ def test_net():
 
     output1_expect = np.array([0, 0, 1, 1, 2, 2])
     output2_expect = np.array([1, 2, 0, 4, 3, 3])
-    output3_expect = np.array([-3.4028235+38, -3.4028235+38, -3.4028235+38,
-                               -3.4028235+38, -3.4028235+38, -3.4028235+38]).astype(np.float32)
+    output3_expect = np.array([-3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
+                               -3.4028235e+38, -3.4028235e+38, -3.4028235e+38]).astype(np.float32)
     assert np.array_equal(output1.asnumpy(), output1_expect)
     assert np.array_equal(output2.asnumpy(), output2_expect)
     assert np.array_equal(output3.asnumpy(), output3_expect)
