
!15286 Add one_hot infer

From: @liangzhibo
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
pull/15286/MERGE
mindspore-ci-bot committed 4 years ago
commit 38924d4425
3 changed files with 29 additions and 32 deletions
  1. mindspore/core/ops/one_hot.cc (+25, -7)
  2. mindspore/core/ops/one_hot.h (+3, -5)
  3. mindspore/ops/operations/nn_ops.py (+1, -20)

mindspore/core/ops/one_hot.cc (+25, -7)

@@ -31,22 +31,40 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
   MS_EXCEPTION_IF_NULL(primitive);
   auto op_name = primitive->name();
   int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
-  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
   CheckAndConvertUtils::CheckInteger("one_hot infer", input_args.size(), kEqual, 4, op_name);
+  MS_EXCEPTION_IF_NULL(input_args[0]);
+  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+  auto in_shape = shape_map[kShape];
+  auto min_shape = shape_map[kMinShape];
+  auto max_shape = shape_map[kMaxShape];
   CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name);
+  MS_EXCEPTION_IF_NULL(input_args[1]);
   auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue());
   CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name);
-  if (axis >= 0) {
-    in_shape.insert(in_shape.begin() + axis, depth_val);
+  if (min_shape.size() == 0 || max_shape.size() == 0) {
+    if (axis >= 0) {
+      in_shape.insert(in_shape.begin() + axis, depth_val);
+    } else {
+      in_shape.push_back(depth_val);
+    }
   } else {
-    in_shape.push_back(depth_val);
+    if (axis >= 0) {
+      in_shape.insert(in_shape.begin() + axis, depth_val);
+      min_shape.insert(min_shape.begin() + axis, depth_val);
+      max_shape.insert(max_shape.begin() + axis, depth_val);
+    } else {
+      in_shape.push_back(depth_val);
+      min_shape.push_back(depth_val);
+      max_shape.push_back(depth_val);
+    }
   }
-  return std::make_shared<abstract::Shape>(in_shape);
+  return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
 }

 TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(prim);
   auto op_name = prim->name();
-  CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name);
+  CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name);
+  CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name);
   std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()},
                                          {"off_dtype", input_args[3]->BuildType()}};
@@ -58,6 +76,6 @@ AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const Primitive
   return std::make_shared<abstract::AbstractTensor>(OneHotInferType(primitive, input_args),
                                                     OneHotInferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_C(kNameOneHot, OneHot);
+REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
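
In short, the new C++ infer inserts depth into the indices shape at position axis (appending it when axis is -1), and now carries min/max shapes through as well, so dynamic-shape inputs keep a valid shape range. A minimal Python sketch of that shape rule, for readers skimming the diff (one_hot_infer_shape is an illustrative helper, not code from this patch):

def one_hot_infer_shape(in_shape, depth, axis, min_shape=None, max_shape=None):
    """Python sketch of the new OneHotInferShape logic (illustration only)."""
    assert -1 <= axis <= len(in_shape), "axis must be in [-1, rank(indices)]"
    assert depth >= 0, "depth must be non-negative"

    def insert_depth(shape):
        shape = list(shape)
        if axis >= 0:
            shape.insert(axis, depth)   # new one-hot dimension at position `axis`
        else:
            shape.append(depth)         # axis == -1: append as the last dimension
        return shape

    out_shape = insert_depth(in_shape)
    if not min_shape or not max_shape:  # static-shape input: no range to track
        return out_shape, None, None
    # dynamic-shape input: min/max shapes gain the same new dimension
    return out_shape, insert_depth(min_shape), insert_depth(max_shape)

# Example: indices of shape [2, 3], depth 5, axis -1 -> output shape [2, 3, 5]
print(one_hot_infer_shape([2, 3], 5, -1))  # ([2, 3, 5], None, None)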

mindspore/core/ops/one_hot.h (+3, -5)

@@ -25,19 +25,17 @@

 namespace mindspore {
 namespace ops {
 constexpr auto kNameOneHot = "OneHot";
 class OneHot : public PrimitiveC {
  public:
-  OneHot() : PrimitiveC(kNameOneHot) { InitIOName({"indices", "depth", "on_value", "off_value"}, {"output"}); }
+  OneHot() : PrimitiveC(prim::kPrimOneHot->name()) {
+    InitIOName({"indices", "depth", "on_value", "off_value"}, {"output"});
+  }
   ~OneHot() = default;
   MS_DECLARE_PARENT(OneHot, PrimitiveC);
   void Init(const int64_t axis);
   void set_axis(const int64_t axis);
   int64_t get_axis() const;
 };

 AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args);
-using PrimOneHotPtr = std::shared_ptr<OneHot>;
 }  // namespace ops
 }  // namespace mindspore


mindspore/ops/operations/nn_ops.py (+1, -20)

@@ -3228,7 +3228,7 @@ class ResizeBilinear(PrimitiveWithInfer):
         return mstype.tensor_type(mstype.float32)


-class OneHot(PrimitiveWithInfer):
+class OneHot(Primitive):
     r"""
     Computes a one-hot tensor.

@@ -3279,25 +3279,6 @@ class OneHot(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output'])
         validator.check_value_type("axis", axis, [int], self.name)

-    def __infer__(self, indices, depth, on_value, off_value):
-        # check type
-        validator.check_tensor_dtype_valid("indices", indices['dtype'], (mstype.int32, mstype.int64), self.name)
-        validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name)
-        args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']}
-        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
-
-        # check shape
-        indices_shp = indices['shape']
-        validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name)
-        depth_val = depth['value']
-        validator.check_non_negative_int(depth_val, "depth", self.name)
-        # create new dimension at end if self.axis is -1
-        _ = indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val)
-
-        return {'shape': indices_shp,
-                'dtype': on_value['dtype'],
-                'value': None}


 class Gelu(PrimitiveWithInfer):
     """

