|
|
|
@@ -109,39 +109,15 @@ class Measurement : public mindspore::TrainCallBack { |
|
|
|
unsigned int epochs_; |
|
|
|
}; |
|
|
|
|
|
|
|
// Definition of verbose callback function after forwarding operator. |
|
|
|
bool after_callback(const std::vector<mindspore::MSTensor> &after_inputs, |
|
|
|
const std::vector<mindspore::MSTensor> &after_outputs, |
|
|
|
const mindspore::MSCallBackParam &call_param) { |
|
|
|
printf("%s\n", call_param.node_name.c_str()); |
|
|
|
for (size_t i = 0; i < after_inputs.size(); i++) { |
|
|
|
int num2p = (after_inputs.at(i).ElementNum()); |
|
|
|
printf("in%zu(%d): ", i, num2p); |
|
|
|
if (num2p > kPrintNum) num2p = kPrintNum; |
|
|
|
if (after_inputs.at(i).DataType() == mindspore::DataType::kNumberTypeInt32) { |
|
|
|
auto d = reinterpret_cast<const int *>(after_inputs.at(i).Data().get()); |
|
|
|
for (int j = 0; j < num2p; j++) printf("%d, ", d[j]); |
|
|
|
} else { |
|
|
|
auto d = reinterpret_cast<const float *>(after_inputs.at(i).Data().get()); |
|
|
|
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]); |
|
|
|
} |
|
|
|
printf("\n"); |
|
|
|
NetRunner::~NetRunner() { |
|
|
|
if (model_ != nullptr) { |
|
|
|
delete model_; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < after_outputs.size(); i++) { |
|
|
|
auto d = reinterpret_cast<const float *>(after_inputs.at(i).Data().get()); |
|
|
|
int num2p = (after_outputs.at(i).ElementNum()); |
|
|
|
printf("ou%zu(%d): ", i, num2p); |
|
|
|
if (num2p > kElem2Print) { |
|
|
|
num2p = kElem2Print; |
|
|
|
} |
|
|
|
for (int j = 0; j < num2p; j++) printf("%f, ", d[j]); |
|
|
|
printf("\n"); |
|
|
|
if (graph_ != nullptr) { |
|
|
|
delete graph_; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
NetRunner::~NetRunner() {} |
|
|
|
|
|
|
|
void NetRunner::InitAndFigureInputs() { |
|
|
|
auto context = std::make_shared<mindspore::Context>(); |
|
|
|
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>(); |
|
|
|
@@ -149,6 +125,8 @@ void NetRunner::InitAndFigureInputs() { |
|
|
|
context->MutableDeviceInfo().push_back(cpu_context); |
|
|
|
|
|
|
|
graph_ = new mindspore::Graph(); |
|
|
|
MS_ASSERT(graph_ != nullptr); |
|
|
|
|
|
|
|
auto status = mindspore::Serialization::Load(ms_file_, mindspore::kMindIR, graph_); |
|
|
|
if (status != mindspore::kSuccess) { |
|
|
|
std::cout << "Error " << status << " during serialization of graph " << ms_file_; |
|
|
|
@@ -161,6 +139,8 @@ void NetRunner::InitAndFigureInputs() { |
|
|
|
} |
|
|
|
|
|
|
|
model_ = new mindspore::Model(); |
|
|
|
MS_ASSERT(model_ != nullptr); |
|
|
|
|
|
|
|
status = model_->Build(mindspore::GraphCell(*graph_), context, cfg); |
|
|
|
if (status != mindspore::kSuccess) { |
|
|
|
std::cout << "Error " << status << " during build of model " << ms_file_; |
|
|
|
@@ -168,6 +148,7 @@ void NetRunner::InitAndFigureInputs() { |
|
|
|
} |
|
|
|
|
|
|
|
acc_metrics_ = std::shared_ptr<AccuracyMetrics>(new AccuracyMetrics); |
|
|
|
MS_ASSERT(acc_metrics_ != nullptr); |
|
|
|
model_->InitMetrics({acc_metrics_.get()}); |
|
|
|
|
|
|
|
auto inputs = model_->GetInputs(); |
|
|
|
|