|
|
|
@@ -38,12 +38,13 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> & |
|
|
|
} |
|
|
|
if (this->primitive_->value.value == nullptr) { |
|
|
|
auto attr = new (std::nothrow) schema::FusedBatchNormT(); |
|
|
|
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon")); |
|
|
|
this->primitive_->value.value = attr; |
|
|
|
if (this->primitive_->value.value == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new primitiveT value failed"; |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new FusedBatchNormT failed"; |
|
|
|
delete this->primitive_; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon")); |
|
|
|
this->primitive_->value.value = attr; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|