diff --git a/mindspore/core/ops/one_hot.cc b/mindspore/core/ops/one_hot.cc index f559dce969..33b75da901 100644 --- a/mindspore/core/ops/one_hot.cc +++ b/mindspore/core/ops/one_hot.cc @@ -31,22 +31,40 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); int64_t axis = GetValue(primitive->GetAttr(kAxis)); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + CheckAndConvertUtils::CheckInteger("one_hot infer", input_args.size(), kEqual, 4, op_name); + MS_EXCEPTION_IF_NULL(input_args[0]); + auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); + auto in_shape = shape_map[kShape]; + auto max_shape = shape_map[kMinShape]; + auto min_shape = shape_map[kMaxShape]; CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); + MS_EXCEPTION_IF_NULL(input_args[1]); auto depth_val = GetValue(input_args[1]->BuildValue()); CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name); - if (axis >= 0) { - in_shape.insert(in_shape.begin() + axis, depth_val); + if (min_shape.size() == 0 || max_shape.size() == 0) { + if (axis >= 0) { + in_shape.insert(in_shape.begin() + axis, depth_val); + } else { + in_shape.push_back(depth_val); + } } else { - in_shape.push_back(depth_val); + if (axis >= 0) { + in_shape.insert(in_shape.begin() + axis, depth_val); + min_shape.insert(min_shape.begin() + axis, depth_val); + max_shape.insert(max_shape.begin() + axis, depth_val); + } else { + in_shape.push_back(depth_val); + min_shape.push_back(depth_val); + max_shape.push_back(depth_val); + } } - return std::make_shared(in_shape); + return std::make_shared(in_shape, min_shape, max_shape); } TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); auto op_name = prim->name(); - CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name); + CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name); CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name); std::map args = {{"on_value", input_args[2]->BuildType()}, {"off_dtype", input_args[3]->BuildType()}}; @@ -58,6 +76,6 @@ AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const Primitive return std::make_shared(OneHotInferType(primitive, input_args), OneHotInferShape(primitive, input_args)->shape()); } -REGISTER_PRIMITIVE_C(kNameOneHot, OneHot); +REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/one_hot.h b/mindspore/core/ops/one_hot.h index 36886cf261..b953224bb6 100644 --- a/mindspore/core/ops/one_hot.h +++ b/mindspore/core/ops/one_hot.h @@ -25,19 +25,17 @@ namespace mindspore { namespace ops { -constexpr auto kNameOneHot = "OneHot"; class OneHot : public PrimitiveC { public: - OneHot() : PrimitiveC(kNameOneHot) { InitIOName({"indices", "depth", "on_value", "off_value"}, {"output"}); } + OneHot() : PrimitiveC(prim::kPrimOneHot->name()) { + InitIOName({"indices", "depth", "on_value", "off_value"}, {"output"}); + } ~OneHot() = default; MS_DECLARE_PARENT(OneHot, PrimitiveC); void Init(const int64_t axis); void set_axis(const int64_t axis); int64_t get_axis() const; }; - -AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); using PrimOneHotPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 86723eed30..89fdc6fd72 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3228,7 +3228,7 @@ class ResizeBilinear(PrimitiveWithInfer): return mstype.tensor_type(mstype.float32) -class OneHot(PrimitiveWithInfer): +class OneHot(Primitive): r""" Computes a one-hot tensor. @@ -3279,25 +3279,6 @@ class OneHot(PrimitiveWithInfer): self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output']) validator.check_value_type("axis", axis, [int], self.name) - def __infer__(self, indices, depth, on_value, off_value): - # check type - validator.check_tensor_dtype_valid("indices", indices['dtype'], (mstype.int32, mstype.int64), self.name) - validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name) - args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']} - validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) - - # check shape - indices_shp = indices['shape'] - validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name) - depth_val = depth['value'] - validator.check_non_negative_int(depth_val, "depth", self.name) - # create new dimension at end if self.axis is -1 - _ = indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val) - - return {'shape': indices_shp, - 'dtype': on_value['dtype'], - 'value': None} - class Gelu(PrimitiveWithInfer): """