|
|
|
@@ -208,7 +208,8 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) { |
|
|
|
Shapes splittable_inputs = {{1, 1}, {}, {}}; |
|
|
|
Shape input0_split(outputs_shape_[0].size(), 1); |
|
|
|
Shapes splittable_inputs = {input0_split, {}, {}}; |
|
|
|
std::vector<StrategyPtr> sp_vector; |
|
|
|
if (inputs_shape_.size() != 3) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); |
|
|
|
|