GitOrigin-RevId: dfc69c3b3f
tags/v1.6.0-rc1
| @@ -12,6 +12,7 @@ | |||
| #include "./cg_impl_seq.h" | |||
| #include "megbrain/graph/exc_extra_info.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/utils/arith_helper.h" | |||
| using namespace mgb; | |||
| using namespace cg; | |||
| @@ -298,6 +299,9 @@ void ComputingGraphImpl::ComputingSequence::do_execute( | |||
| } | |||
| exec_ctx.perform(&m_exec_env); | |||
| #ifndef __IN_TEE_ENV__ | |||
| do_regist(); | |||
| #endif | |||
| } | |||
| void ComputingGraphImpl::ComputingSequence::preprocess(ExecContext* ctx) { | |||
| @@ -511,35 +515,42 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||
| } | |||
| #ifndef __IN_TEE_ENV__ | |||
| void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | |||
| const std::string& svg_name) { | |||
| check_not_finalized(); | |||
| const std::string& svg_name) const { | |||
| auto& recorder = StaticMemRecorder::Instance(); | |||
| recorder.active(); | |||
| ExecContext exec_ctx{this}; | |||
| recorder.set_svg_name(svg_name); | |||
| } | |||
| void ComputingGraphImpl::ComputingSequence::do_regist() const { | |||
| // regist weights | |||
| size_t addr_base = recorder.peak_mem_size(); | |||
| size_t chunk_id = recorder.set_weight_chunk_id(); | |||
| for (auto&& i : *(this->m_opr_seq)) { | |||
| auto op = i->output(); | |||
| for (auto&& j : op) { | |||
| auto& mp = j->mem_plan(); | |||
| if (mp.valid()) { | |||
| auto& mc = mp.chunk(); | |||
| if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||
| recorder.regist_memory_chunk( | |||
| {chunk_id++, mc.size(), 0, this->m_opr_seq->size(), | |||
| addr_base, addr_base + mc.size(), 0, false, | |||
| mc.owner_var->name()}); | |||
| addr_base += mc.size(); | |||
| auto& recorder = StaticMemRecorder::Instance(); | |||
| if (recorder.valid()) { | |||
| size_t addr_base = recorder.peak_mem_size(); | |||
| size_t chunk_id = recorder.set_weight_chunk_id(); | |||
| for (auto&& i : *(this->m_opr_seq)) { | |||
| auto op = i->output(); | |||
| for (auto&& j : op) { | |||
| auto& mp = j->mem_plan(); | |||
| if (mp.valid()) { | |||
| auto& mc = mp.chunk(); | |||
| if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||
| auto size = mgb::get_aligned_power2( | |||
| mc.size(), | |||
| j->comp_node().get_mem_addr_alignment()); | |||
| recorder.regist_memory_chunk( | |||
| {chunk_id++, size, 0, this->m_opr_seq->size(), | |||
| addr_base, addr_base + size, 0, false, | |||
| mc.owner_var->name()}); | |||
| addr_base += size; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| recorder.set_sum_mem_size(addr_base); | |||
| recorder.show(); | |||
| } | |||
| recorder.set_sum_mem_size(addr_base); | |||
| mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); | |||
| mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||
| "svg_name must be end with \".svg\"\n"); | |||
| recorder.show(svg_name); | |||
| } | |||
| #endif | |||
| AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | |||
| @@ -174,7 +174,10 @@ public: | |||
| std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | |||
| #ifndef __IN_TEE_ENV__ | |||
| void get_static_memory_alloc_info( | |||
| const std::string& svg_name = "static_mem_record.svg") override; | |||
| const std::string& svg_name = | |||
| "static_mem_record.svg") const override; | |||
| void do_regist() const; | |||
| #endif | |||
| }; | |||
| @@ -195,7 +195,8 @@ class AsyncExecutable : public json::Serializable, | |||
| return (*(output_vars_pair.first))->get_output_vars(); | |||
| } | |||
| #ifndef __IN_TEE_ENV__ | |||
| virtual void get_static_memory_alloc_info(const std::string& svg_name) { | |||
| virtual void get_static_memory_alloc_info( | |||
| const std::string& svg_name) const { | |||
| mgb_assert(svg_name.length() < 0, | |||
| "can't call this function directly\n"); | |||
| } | |||
| @@ -86,7 +86,7 @@ std::string draw_polyline(std::string point_seq, std::string color, | |||
| } | |||
| } // namespace | |||
| void StaticMemRecorder::dump_svg(std::string svg_name) { | |||
| void StaticMemRecorder::dump_svg() { | |||
| float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | |||
| opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | |||
| float address_scale = 1; | |||
| @@ -120,7 +120,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||
| svg_height = svg_height + opr_rect_height * 2; | |||
| std::ofstream outfile; | |||
| outfile.open(svg_name); | |||
| outfile.open(m_svg_name); | |||
| outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | |||
| outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | |||
| "\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | |||
| @@ -243,7 +243,7 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||
| outfile.close(); | |||
| } | |||
| void StaticMemRecorder::show(std::string svg_name) { | |||
| void StaticMemRecorder::show() { | |||
| for (auto&& i : m_memory_chunk_recorder) { | |||
| if (i.id >= m_weight_chunk_id) { | |||
| break; | |||
| @@ -291,7 +291,7 @@ void StaticMemRecorder::show(std::string svg_name) { | |||
| m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | |||
| } | |||
| } | |||
| dump_svg(svg_name); | |||
| dump_svg(); | |||
| } | |||
| std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | |||
| @@ -54,25 +54,38 @@ public: | |||
| void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | |||
| const size_t& peak_mem_size() { return m_peak_mem_size; } | |||
| const size_t& peak_mem_size() const { return m_peak_mem_size; } | |||
| void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | |||
| const size_t& sum_mem_size() { return m_sum_mem_size; } | |||
| const size_t& sum_mem_size() const { return m_sum_mem_size; } | |||
| const size_t& set_weight_chunk_id() { | |||
| m_weight_chunk_id = m_memory_chunk_recorder.size(); | |||
| return m_weight_chunk_id; | |||
| } | |||
| const size_t& weight_chunk_id() { return m_weight_chunk_id; } | |||
| const size_t& weight_chunk_id() const { return m_weight_chunk_id; } | |||
| void dump_svg(std::string svg_name); | |||
| void dump_svg(); | |||
| void show(std::string svg_name); | |||
| void show(); | |||
| void set_svg_name(const std::string& svg_name) { | |||
| mgb_assert(svg_name.length() > 4, | |||
| "svg_name must be end with \".svg\"\n"); | |||
| mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||
| "svg_name must be end with \".svg\"\n"); | |||
| m_svg_name = svg_name; | |||
| } | |||
| const std::string& get_svg_name() const{ | |||
| return m_svg_name; | |||
| } | |||
| private: | |||
| bool m_is_record = false; | |||
| std::string m_svg_name; | |||
| // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | |||
| // weights memory chunks | |||
| size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | |||