|
|
|
@@ -31,22 +31,40 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
auto op_name = primitive->name(); |
|
|
|
int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); |
|
|
|
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; |
|
|
|
CheckAndConvertUtils::CheckInteger("one_hot infer", input_args.size(), kEqual, 4, op_name); |
|
|
|
MS_EXCEPTION_IF_NULL(input_args[0]); |
|
|
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); |
|
|
|
auto in_shape = shape_map[kShape]; |
|
|
|
auto max_shape = shape_map[kMinShape]; |
|
|
|
auto min_shape = shape_map[kMaxShape]; |
|
|
|
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); |
|
|
|
MS_EXCEPTION_IF_NULL(input_args[1]); |
|
|
|
auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue()); |
|
|
|
CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name); |
|
|
|
if (axis >= 0) { |
|
|
|
in_shape.insert(in_shape.begin() + axis, depth_val); |
|
|
|
if (min_shape.size() == 0 || max_shape.size() == 0) { |
|
|
|
if (axis >= 0) { |
|
|
|
in_shape.insert(in_shape.begin() + axis, depth_val); |
|
|
|
} else { |
|
|
|
in_shape.push_back(depth_val); |
|
|
|
} |
|
|
|
} else { |
|
|
|
in_shape.push_back(depth_val); |
|
|
|
if (axis >= 0) { |
|
|
|
in_shape.insert(in_shape.begin() + axis, depth_val); |
|
|
|
min_shape.insert(min_shape.begin() + axis, depth_val); |
|
|
|
max_shape.insert(max_shape.begin() + axis, depth_val); |
|
|
|
} else { |
|
|
|
in_shape.push_back(depth_val); |
|
|
|
min_shape.push_back(depth_val); |
|
|
|
max_shape.push_back(depth_val); |
|
|
|
} |
|
|
|
} |
|
|
|
return std::make_shared<abstract::Shape>(in_shape); |
|
|
|
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape); |
|
|
|
} |
|
|
|
|
|
|
|
TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
auto op_name = prim->name(); |
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name); |
|
|
|
CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name); |
|
|
|
CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name); |
|
|
|
std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()}, |
|
|
|
{"off_dtype", input_args[3]->BuildType()}}; |
|
|
|
@@ -58,6 +76,6 @@ AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const Primitive |
|
|
|
return std::make_shared<abstract::AbstractTensor>(OneHotInferType(primitive, input_args), |
|
|
|
OneHotInferShape(primitive, input_args)->shape()); |
|
|
|
} |
|
|
|
REGISTER_PRIMITIVE_C(kNameOneHot, OneHot); |
|
|
|
REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true); |
|
|
|
} // namespace ops |
|
|
|
} // namespace mindspore |