|
|
|
@@ -82,7 +82,8 @@ Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator); |
|
|
|
|
|
|
|
namespace { |
|
|
|
constexpr size_t kOneHotInputNum = 4; |
|
|
|
} |
|
|
|
constexpr size_t kOneHotInputNumOpt = 3; |
|
|
|
} // namespace |
|
|
|
int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { |
|
|
|
if (this->primitive_ == nullptr) { |
|
|
|
return RET_NULL_PTR; |
|
|
|
@@ -90,8 +91,10 @@ int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outpu |
|
|
|
|
|
|
|
int axis = GetAxis(); |
|
|
|
// indices, depth, on_value, off_value |
|
|
|
if (inputs.size() != kOneHotInputNum) { |
|
|
|
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum; |
|
|
|
// indices, depth, on_off_value(contain 2 values); |
|
|
|
if (inputs.size() != kOneHotInputNum && inputs.size() != kOneHotInputNumOpt) { |
|
|
|
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum << " or " |
|
|
|
<< kOneHotInputNumOpt; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto depth_tensor = inputs.at(1); |
|
|
|
|