|
|
|
@@ -389,22 +389,22 @@ OpParameter *PopulateDropoutParameter(const void *prim) { |
|
|
|
} |
|
|
|
|
|
|
|
OpParameter *PopulateDropoutGradParameter(const void *prim) { |
|
|
|
DropoutParameter *dropoutGrad_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter))); |
|
|
|
if (dropoutGrad_parameter == nullptr) { |
|
|
|
DropoutParameter *dropoutgrad_parameter = reinterpret_cast<DropoutParameter *>(malloc(sizeof(DropoutParameter))); |
|
|
|
if (dropoutgrad_parameter == nullptr) { |
|
|
|
MS_LOG(ERROR) << "malloc Dropout Grad Parameter failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
memset(dropoutGrad_parameter, 0, sizeof(DropoutParameter)); |
|
|
|
memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter)); |
|
|
|
auto primitive = static_cast<const schema::Primitive *>(prim); |
|
|
|
auto value = primitive->value_as_DropoutGrad(); |
|
|
|
dropoutGrad_parameter->op_parameter_.type_ = primitive->value_type(); |
|
|
|
dropoutGrad_parameter->ratio_ = value->ratio(); |
|
|
|
if (dropoutGrad_parameter->ratio_ < 0.f || dropoutGrad_parameter->ratio_ > 1.f) { |
|
|
|
MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutGrad_parameter->ratio_; |
|
|
|
free(dropoutGrad_parameter); |
|
|
|
dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type(); |
|
|
|
dropoutgrad_parameter->ratio_ = value->keep_prob(); |
|
|
|
if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) { |
|
|
|
MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutgrad_parameter->ratio_; |
|
|
|
free(dropoutgrad_parameter); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return reinterpret_cast<OpParameter *>(dropoutGrad_parameter); |
|
|
|
return reinterpret_cast<OpParameter *>(dropoutgrad_parameter); |
|
|
|
} |
|
|
|
|
|
|
|
OpParameter *PopulateArithmeticGradParameter(const void *prim) { |
|
|
|
|