Browse Source

fix mean and var shape in layernorm

tags/v0.3.0-alpha
zhaojichen 5 years ago
parent
commit
c622879a2f
1 changed files with 8 additions and 2 deletions
  1. +8
    -2
      mindspore/ccsrc/operator/prim_nn.cc

+ 8
- 2
mindspore/ccsrc/operator/prim_nn.cc View File

@@ -301,7 +301,7 @@ AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr

// begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
(void)CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1);
int begin_norm_axis = CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1);

ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1);
@@ -341,7 +341,13 @@ AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr
}

auto mean_var_shape_value = input_shape->shape();
mean_var_shape_value[input_rank - 1] = 1;
if (begin_norm_axis == -1) {
mean_var_shape_value[input_rank - 1] = 1;
} else {
for (size_t i = begin_norm_axis; i < input_rank; ++i) {
mean_var_shape_value[i] = 1;
}
}

auto mean = input_x->Broaden();
mean->set_shape(std::make_shared<Shape>(mean_var_shape_value));


Loading…
Cancel
Save