Browse Source

fix a bug of onehot

feature/build-system-rewrite
b00518648 4 years ago
parent
commit
4ac7e1c34a
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc

+ 2
- 1
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc View File

@@ -208,7 +208,8 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) {
}

std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
Shapes splittable_inputs = {{1, 1}, {}, {}};
Shape input0_split(outputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split, {}, {}};
std::vector<StrategyPtr> sp_vector;
if (inputs_shape_.size() != 3) {
MS_LOG(EXCEPTION) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();


Loading…
Cancel
Save