Browse Source

!14519 Add infer for Softmax in C++

From: @liangzhibo
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/14519/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
07bf048857
2 changed files with 15 additions and 16 deletions
  1. +15
    -5
      mindspore/core/ops/softmax.cc
  2. +0
    -11
      mindspore/ops/operations/nn_ops.py

+ 15
- 5
mindspore/core/ops/softmax.cc View File

@@ -46,15 +46,24 @@ void Softmax::Init(const int64_t axis) {

abstract::ShapePtr SoftMaxInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto Softmax_prim = primitive->cast<PrimSoftmaxPtr>();
MS_EXCEPTION_IF_NULL(Softmax_prim);
auto op_name = Softmax_prim->name();
auto axis = Softmax_prim->get_axis();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
auto op_name = primitive->name();
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
(void)CheckAndConvertUtils::CheckValue<size_t>("length of axis", axis.size(), kGreaterEqual, 1, op_name);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
if (shape_map.empty()) {
// Scalar input, has no shape
return std::make_shared<abstract::Shape>(std::vector<int64_t>());
}
auto in_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];
auto max_shape = shape_map[kMaxShape];
auto rank = SizeToLong(in_shape.size());
for (auto &item : axis) {
CheckAndConvertUtils::CheckInRange<int64_t>("axis", item, kIncludeLeft, {-rank, rank}, op_name);
}
if (min_shape.size() != 0 && max_shape.size() != 0) {
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}
return std::make_shared<abstract::Shape>(in_shape);
}

@@ -71,6 +80,7 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
return std::make_shared<abstract::AbstractTensor>(SoftMaxInferType(primitive, input_args),
SoftMaxInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax);
} // namespace ops
} // namespace mindspore

+ 0
- 11
mindspore/ops/operations/nn_ops.py View File

@@ -177,17 +177,6 @@ class Softmax(PrimitiveWithInfer):
for item in self.axis:
validator.check_value_type("item of axis", item, [int], self.name)

def infer_shape(self, logits):
validator.check_int(len(self.axis), 1, Rel.GE, "length of axis", self.name)
rank = len(logits)
for axis_v in self.axis:
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
return logits

def infer_dtype(self, logits):
validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
return logits


class LogSoftmax(PrimitiveWithInfer):
r"""


Loading…
Cancel
Save