|
|
|
@@ -348,6 +348,7 @@ std::unique_ptr<session::LiteSession> NetTrain::CreateAndRunNetworkForTrain(cons |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "CreateTrainSession from model file" << filename.c_str(); |
|
|
|
std::cout << "CreateTrainSession from model file" << filename.c_str() << std::endl; |
|
|
|
session = std::unique_ptr<session::LiteSession>( |
|
|
|
session::TrainSession::CreateTrainSession(filename, &context, true, &train_cfg)); |
|
|
|
if (session == nullptr) { |
|
|
|
@@ -424,7 +425,7 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string |
|
|
|
if (train_session) { |
|
|
|
session = CreateAndRunNetworkForTrain(filename, bb_filename, context, train_cfg, epochs); |
|
|
|
if (session == nullptr) { |
|
|
|
MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed."; |
|
|
|
MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} else { |
|
|
|
|