GitOrigin-RevId: dfc69c3b3f
tags/v1.6.0-rc1
| @@ -12,6 +12,7 @@ | |||||
| #include "./cg_impl_seq.h" | #include "./cg_impl_seq.h" | ||||
| #include "megbrain/graph/exc_extra_info.h" | #include "megbrain/graph/exc_extra_info.h" | ||||
| #include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
| #include "megbrain/utils/arith_helper.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace cg; | using namespace cg; | ||||
| @@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute( | |||||
| } | } | ||||
| exec_ctx.perform(&m_exec_env); | exec_ctx.perform(&m_exec_env); | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| do_regist(); | |||||
| #endif | |||||
| } | } | ||||
| void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { | void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { | ||||
| @@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||||
| } | } | ||||
| #ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
| void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | ||||
| const std::string& svg_name) { | |||||
| check_not_finalized(); | |||||
| const std::string& svg_name) const { | |||||
| auto& recorder = StaticMemRecorder::Instance(); | auto& recorder = StaticMemRecorder::Instance(); | ||||
| recorder.active(); | recorder.active(); | ||||
| ExecContext exec_ctx{this}; | |||||
| recorder.set_svg_name(svg_name); | |||||
| } | |||||
| void ComputingGraphImpl::ComputingSequence::do_regist() const { | |||||
| // regist weights | // regist weights | ||||
| size_t addr_base = recorder.peak_mem_size(); | |||||
| size_t chunk_id = recorder.set_weight_chunk_id(); | |||||
| for (auto&& i : *(this->m_opr_seq)) { | |||||
| auto op = i->output(); | |||||
| for (auto&& j : op) { | |||||
| auto& mp = j->mem_plan(); | |||||
| if (mp.valid()) { | |||||
| auto& mc = mp.chunk(); | |||||
| if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||||
| recorder.regist_memory_chunk( | |||||
| {chunk_id++, mc.size(), 0, this->m_opr_seq->size(), | |||||
| addr_base, addr_base + mc.size(), 0, false, | |||||
| mc.owner_var->name()}); | |||||
| addr_base += mc.size(); | |||||
| auto& recorder = StaticMemRecorder::Instance(); | |||||
| if (recorder.valid()) { | |||||
| size_t addr_base = recorder.peak_mem_size(); | |||||
| size_t chunk_id = recorder.set_weight_chunk_id(); | |||||
| for (auto&& i : *(this->m_opr_seq)) { | |||||
| auto op = i->output(); | |||||
| for (auto&& j : op) { | |||||
| auto& mp = j->mem_plan(); | |||||
| if (mp.valid()) { | |||||
| auto& mc = mp.chunk(); | |||||
| if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||||
| auto size = mgb::get_aligned_power2( | |||||
| mc.size(), | |||||
| j->comp_node().get_mem_addr_alignment()); | |||||
| recorder.regist_memory_chunk( | |||||
| {chunk_id++, size, 0, this->m_opr_seq->size(), | |||||
| addr_base, addr_base + size, 0, false, | |||||
| mc.owner_var->name()}); | |||||
| addr_base += size; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| recorder.set_sum_mem_size(addr_base); | |||||
| recorder.show(); | |||||
| } | } | ||||
| recorder.set_sum_mem_size(addr_base); | |||||
| mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); | |||||
| mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||||
| "svg_name must be end with \".svg\"\n"); | |||||
| recorder.show(svg_name); | |||||
| } | } | ||||
| #endif | #endif | ||||
| AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | ||||
| @@ -174,7 +174,10 @@ public: | |||||
| std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | ||||
| #ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
| void get_static_memory_alloc_info( | void get_static_memory_alloc_info( | ||||
| const std::string& svg_name = "static_mem_record.svg") override; | |||||
| const std::string& svg_name = | |||||
| "static_mem_record.svg") const override; | |||||
| void do_regist() const; | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable, | |||||
| return (*(output_vars_pair.first))->get_output_vars(); | return (*(output_vars_pair.first))->get_output_vars(); | ||||
| } | } | ||||
| #ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
| virtual void get_static_memory_alloc_info(const std::string& svg_name) { | |||||
| virtual void get_static_memory_alloc_info( | |||||
| const std::string& svg_name) const { | |||||
| mgb_assert(svg_name.length() < 0, | mgb_assert(svg_name.length() < 0, | ||||
| "can't call this function directly\n"); | "can't call this function directly\n"); | ||||
| } | } | ||||
| @@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color, | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
| void StaticMemRecorder::dump_svg() { | |||||
| float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | ||||
| opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | ||||
| float address_scale = 1; | float address_scale = 1; | ||||
| @@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
| svg_height = svg_height + opr_rect_height * 2; | svg_height = svg_height + opr_rect_height * 2; | ||||
| std::ofstream outfile; | std::ofstream outfile; | ||||
| outfile.open(svg_name); | |||||
| outfile.open(m_svg_name); | |||||
| outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | ||||
| outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | ||||
| "\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | "\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | ||||
| @@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
| outfile.close(); | outfile.close(); | ||||
| } | } | ||||
| void StaticMemRecorder::show(std::string svg_name) { | |||||
| void StaticMemRecorder::show() { | |||||
| for (auto&& i : m_memory_chunk_recorder) { | for (auto&& i : m_memory_chunk_recorder) { | ||||
| if (i.id >= m_weight_chunk_id) { | if (i.id >= m_weight_chunk_id) { | ||||
| break; | break; | ||||
| @@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) { | |||||
| m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | ||||
| } | } | ||||
| } | } | ||||
| dump_svg(svg_name); | |||||
| dump_svg(); | |||||
| } | } | ||||
| std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | ||||
| @@ -54,25 +54,38 @@ public: | |||||
| void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | ||||
| const size_t& peak_mem_size() { return m_peak_mem_size; } | |||||
| const size_t& peak_mem_size() const { return m_peak_mem_size; } | |||||
| void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | ||||
| const size_t& sum_mem_size() { return m_sum_mem_size; } | |||||
| const size_t& sum_mem_size() const { return m_sum_mem_size; } | |||||
| const size_t& set_weight_chunk_id() { | const size_t& set_weight_chunk_id() { | ||||
| m_weight_chunk_id = m_memory_chunk_recorder.size(); | m_weight_chunk_id = m_memory_chunk_recorder.size(); | ||||
| return m_weight_chunk_id; | return m_weight_chunk_id; | ||||
| } | } | ||||
| const size_t& weight_chunk_id() { return m_weight_chunk_id; } | |||||
| const size_t& weight_chunk_id() const { return m_weight_chunk_id; } | |||||
| void dump_svg(std::string svg_name); | |||||
| void dump_svg(); | |||||
| void show(std::string svg_name); | |||||
| void show(); | |||||
| void set_svg_name(const std::string& svg_name) { | |||||
| mgb_assert(svg_name.length() > 4, | |||||
| "svg_name must be end with \".svg\"\n"); | |||||
| mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||||
| "svg_name must be end with \".svg\"\n"); | |||||
| m_svg_name = svg_name; | |||||
| } | |||||
| const std::string& get_svg_name() const{ | |||||
| return m_svg_name; | |||||
| } | |||||
| private: | private: | ||||
| bool m_is_record = false; | bool m_is_record = false; | ||||
| std::string m_svg_name; | |||||
| // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | ||||
| // weights memory chunks | // weights memory chunks | ||||
| size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | ||||