| @@ -0,0 +1,154 @@ | |||||
| <html> | |||||
| <title>Visualizer</title> | |||||
| <head> | |||||
| <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" /> | |||||
| </head> | |||||
| <script> | |||||
// Wires up the visualizer once the page is loaded: reads the SVG chosen in
// #fileInput into #board, makes every element carrying an `mge:type`
// attribute selectable (its `mge:*` attributes are listed in #desc), lists the
// `svg:*` legend attributes in #info, and zooms the board on Ctrl + wheel.
window.onload = () => {
    const board = document.getElementById('board');
    const fileInput = document.getElementById('fileInput');
    const desc = document.getElementById('desc');
    // Resolved explicitly instead of relying on the implicit window.info global.
    const info = document.getElementById('info');
    // These sliders do not exist in the current markup, so both may be null.
    const hRange = document.getElementById('hRange');
    const vRange = document.getElementById('vRange');

    let lastColor = undefined; // original `fill` of the selected element (null if it had none)
    let lastElem = undefined;  // currently selected element
    let scale = 1;             // accumulated zoom factor
    let svg = undefined;       // root <svg> of the loaded file
    let svgWidth = undefined;  // numeric width/height declared on the root <svg>
    let svgHeight = undefined;

    // Replaces the content of `panel` with one <p>name: value</p> line per
    // attribute of `elem` whose name starts with `prefix` (prefix stripped).
    // Lines are built with textContent so attribute values coming from the
    // file are never interpreted as markup.
    const showPrefixedAttrs = (panel, elem, prefix) => {
        panel.textContent = '';
        for (const attrName of elem.getAttributeNames()) {
            if (!attrName.startsWith(prefix)) {
                continue;
            }
            const line = document.createElement('p');
            line.textContent =
                attrName.slice(prefix.length) + ': ' + elem.getAttribute(attrName);
            panel.appendChild(line);
        }
    };

    // Shows the `mge:*` attributes of `svgElem` in the description panel.
    // Elements without an `mge:type` attribute are ignored.
    const loadDesc = svgElem => {
        if (!svgElem.hasAttribute('mge:type')) {
            return;
        }
        showPrefixedAttrs(desc, svgElem, 'mge:');
    };

    // Highlights `svgElem` in green and remembers its previous fill.
    const selectElem = svgElem => {
        loadDesc(svgElem);
        lastColor = svgElem.getAttribute('fill'); // null when there is no fill attribute
        lastElem = svgElem;
        svgElem.setAttribute('fill', 'green');
    };

    // Restores the fill of the previously selected element, if any.
    const unselectLast = () => {
        if (lastElem) {
            if (lastColor === null || lastColor === undefined) {
                lastElem.removeAttribute('fill');
            } else {
                lastElem.setAttribute('fill', lastColor);
            }
        }
        lastElem = undefined;
        lastColor = undefined;
    };

    // Recursively installs the selection handler on `svgElem` and all of its
    // descendants.
    function recLoadSVG(svgElem) {
        if (svgElem.children === undefined) {
            return;
        }
        svgElem.onmousedown = e => {
            if (!svgElem.hasAttribute('mge:type')) {
                return;
            }
            unselectLast();
            selectElem(svgElem);
            // Keep ancestors from stealing the selection.
            e.stopPropagation();
        };
        for (const child of svgElem.children) {
            recLoadSVG(child);
        }
    }

    // Reads the file chosen in #fileInput and renders it into the board.
    function loadSVG() {
        const file = fileInput.files[0];
        if (!file) {
            // The file dialog was cancelled.
            return;
        }
        const reader = new FileReader();
        reader.onload = e => {
            // Forget the selection and zoom that belonged to the previous file.
            unselectLast();
            scale = 1;
            // NOTE(review): the file content is injected as markup; only open
            // SVG files from a trusted source (inline event-handler attributes
            // in the file would run in this page).
            board.innerHTML = '<p style="margin: 0;">' + e.target.result + '</p>';
            svg = board.querySelector('svg') || undefined;
            if (!svg) {
                svgWidth = undefined;
                svgHeight = undefined;
                return;
            }
            svgWidth = parseFloat(svg.getAttribute('width'));
            svgHeight = parseFloat(svg.getAttribute('height'));
            for (const child of board.children) {
                recLoadSVG(child);
            }
            // The legend attributes live on an element inside the SVG, not on
            // the wrapper paragraph, so search all descendants for it.
            const legend = Array.from(board.querySelectorAll('*'))
                .find(elem => elem.hasAttribute('svg:info'));
            if (legend) {
                showPrefixedAttrs(info, legend, 'svg:');
            }
        };
        reader.readAsText(file, 'UTF-8');
    }

    // Scales the loaded SVG by (x, y) and resizes the board accordingly.
    function scaleBoard(x, y) {
        if (!svg) {
            // Nothing has been loaded yet.
            return;
        }
        svg.setAttribute('transform', 'scale(' + x + ',' + y + ')');
        if (Number.isFinite(svgWidth)) {
            board.style.width = (svgWidth * x) + 'px';
        }
        if (Number.isFinite(svgHeight)) {
            board.style.height = (svgHeight * y) + 'px';
        }
    }

    // Scales the board from the optional #hRange / #vRange sliders.
    function autoScaleBoard() {
        if (!hRange || !vRange) {
            return;
        }
        const hRangeValue = Math.sqrt(Number(hRange.value) / 10);
        const vRangeValue = Math.sqrt(Number(vRange.value) / 10);
        scaleBoard(hRangeValue, vRangeValue);
    }

    fileInput.onchange = loadSVG;

    // Multiplies the current zoom by `dScale`, uniformly on both axes.
    const zoomBoard = dScale => {
        scale *= dScale;
        scaleBoard(scale, scale);
    };

    // Ctrl + wheel zooms the board and scrolls so that the point under the
    // cursor stays in place. The listener is non-passive because it has to
    // cancel the browser's own page zoom.
    window.addEventListener('wheel', e => {
        if (!e.ctrlKey) {
            return;
        }
        e.preventDefault();
        e.stopPropagation();
        let factor = 1;
        if (e.deltaY < 0) {
            factor = 1.1;
        } else if (e.deltaY > 0) {
            factor = 1 / 1.1;
        }
        zoomBoard(factor);
        window.scrollTo({
            top: e.pageY * factor - e.clientY,
            left: e.pageX * factor - e.clientX,
        });
    }, { passive: false });
};
| </script> | |||||
| <body> | |||||
| <p id="desc" style="position: fixed;bottom: 0; background-color: white;">desc</p> | |||||
| <p id="info" style="position: fixed;top: 0; right: 0; background-color: white;">info</p> | |||||
| <p id="board" | |||||
| style="white-space: nowrap; display: flex; justify-content: center; align-content: center; align-items: center; margin: 0;opacity: 0.7;"> | |||||
| </p> | |||||
| <input type='file' id='fileInput' style="position: fixed; top: 0; background-color: white;"></input> | |||||
| </body> | |||||
| </html> | |||||
| @@ -492,6 +492,38 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | |||||
| const std::string& svg_name) { | |||||
| check_not_finalized(); | |||||
| auto& recorder = StaticMemRecorder::Instance(); | |||||
| recorder.active(); | |||||
| ExecContext exec_ctx{this}; | |||||
| // regist weights | |||||
| size_t addr_base = recorder.peak_mem_size(); | |||||
| size_t chunk_id = recorder.set_weight_chunk_id(); | |||||
| for (auto&& i : *(this->m_opr_seq)) { | |||||
| auto op = i->output(); | |||||
| for (auto&& j : op) { | |||||
| auto& mp = j->mem_plan(); | |||||
| if (mp.valid()) { | |||||
| auto& mc = mp.chunk(); | |||||
| if (mp.valid() && mc.mem_alloc_status.is_from_owner_var()) { | |||||
| recorder.regist_memory_chunk( | |||||
| {chunk_id++, mc.size(), 0, this->m_opr_seq->size(), | |||||
| addr_base, addr_base + mc.size(), 0, false, | |||||
| mc.owner_var->name()}); | |||||
| addr_base += mc.size(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| recorder.set_sum_mem_size(addr_base); | |||||
| mgb_assert(svg_name.length() > 4, "svg_name must be end with \".svg\"\n"); | |||||
| mgb_assert(svg_name.compare(svg_name.length() - 4, 4, ".svg") == 0, | |||||
| "svg_name must be end with \".svg\"\n"); | |||||
| recorder.show(svg_name); | |||||
| } | |||||
| AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | ||||
| do_wait(true); | do_wait(true); | ||||
| return *this; | return *this; | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "megbrain/comp_node_env.h" | #include "megbrain/comp_node_env.h" | ||||
| #include "megbrain/plugin/var_sanity_check.h" | #include "megbrain/plugin/var_sanity_check.h" | ||||
| #include "megbrain/utils/arith_helper.h" | #include "megbrain/utils/arith_helper.h" | ||||
| #include "megbrain/plugin/static_mem_record.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace cg { | namespace cg { | ||||
| @@ -169,6 +170,9 @@ public: | |||||
| } | } | ||||
| std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | ||||
| void get_static_memory_alloc_info( | |||||
| const std::string& svg_name = "static_mem_record.svg") override; | |||||
| }; | }; | ||||
| class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | ||||
| @@ -178,9 +178,18 @@ bool SeqMemOptimizer::run_static_mem_alloc() { | |||||
| ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ||||
| // get all memory chunks | // get all memory chunks | ||||
| if (StaticMemRecorder::Instance().valid()) { | |||||
| StaticMemRecorder::Instance().clear_opr_seq(); | |||||
| } | |||||
| for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | ||||
| OperatorNodeBase *opr = m_cur_seq_full->at(idx); | OperatorNodeBase *opr = m_cur_seq_full->at(idx); | ||||
| if (StaticMemRecorder::Instance().valid()) { | |||||
| StaticMemRecorder::Instance().regist_opr_seq( | |||||
| {idx, 0, opr->name()}); | |||||
| } | |||||
| auto &&dep_map = opr->node_prop().dep_map(); | auto &&dep_map = opr->node_prop().dep_map(); | ||||
| if (in_sys_alloc(opr)) { | if (in_sys_alloc(opr)) { | ||||
| @@ -349,6 +358,14 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( | |||||
| chk.chunk->mem_alloc_status.set_static_offset( | chk.chunk->mem_alloc_status.set_static_offset( | ||||
| allocator->get_start_addr(&chk)); | allocator->get_start_addr(&chk)); | ||||
| } | } | ||||
| auto& recorder = StaticMemRecorder::Instance(); | |||||
| if (recorder.valid()) { | |||||
| for (size_t i = 0; i < chunks.size(); i++) { | |||||
| recorder.regist_memory_chunk_owner_var_name( | |||||
| i, chunks.at(i).chunk->owner_var->name()); | |||||
| } | |||||
| recorder.regist_peak_mem_size(size); | |||||
| } | |||||
| } | } | ||||
| return should_realloc; | return should_realloc; | ||||
| @@ -11,6 +11,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain/plugin/static_mem_record.h" | |||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #include <cstddef> | #include <cstddef> | ||||
| @@ -120,6 +120,22 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { | |||||
| check_result_and_calc_lower_bound(); | check_result_and_calc_lower_bound(); | ||||
| if (StaticMemRecorder::Instance().valid()) { | |||||
| StaticMemRecorder::Instance().clear_memory_chunk(); | |||||
| for (auto&& i : m_interval) { | |||||
| size_t overwrite_dest_id = 0; | |||||
| bool is_overwrite = !i->is_overwrite_root(); | |||||
| if (is_overwrite) { | |||||
| overwrite_dest_id = i->overwrite_dest_root()->id; | |||||
| } | |||||
| StaticMemRecorder::Instance().regist_memory_chunk( | |||||
| {i->id, i->size_orig, i->time_begin, i->time_end, | |||||
| i->addr_begin, i->addr_end(), overwrite_dest_id, | |||||
| is_overwrite, ""}); | |||||
| } | |||||
| } | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -190,6 +190,11 @@ class AsyncExecutable : public json::Serializable, | |||||
| m_user_data.get_user_data<OutputVarsUserData>(); | m_user_data.get_user_data<OutputVarsUserData>(); | ||||
| return (*(output_vars_pair.first))->get_output_vars(); | return (*(output_vars_pair.first))->get_output_vars(); | ||||
| } | } | ||||
| virtual void get_static_memory_alloc_info(const std::string& svg_name) { | |||||
| mgb_assert(svg_name.length() < 0, | |||||
| "can't call this function directly\n"); | |||||
| } | |||||
| }; | }; | ||||
| @@ -0,0 +1,319 @@ | |||||
| /** | |||||
| * \file src/plugin/impl/static_mem_record.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain/plugin/static_mem_record.h" | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| using namespace mgb; | |||||
| using namespace cg; | |||||
namespace {

// Upper bounds of the drawing area and the default size of one operator
// cell, all in SVG user units.
constexpr double SVG_WIDTH = 20000.0;
constexpr double SVG_HEIGHT = 15000.0;
constexpr double OPR_RECT_WIDTH = 40.0;
constexpr double OPR_RECT_HEIGHT = 20.0;

// Markup templates; every "{}" is a placeholder filled from left to right by
// replace_by_parameter().
const std::string rect =
        "<rect x=\"{}\" y=\"{}\" width=\"{}\" height=\"{}\" fill=\"{}\" "
        " {}></rect>";
const std::string text = "<text x=\"{}\" y=\"{}\" font-size=\"{}\">{}</text>";
const std::string polyline =
        "<polyline points=\"{}\" style=\"fill:none;stroke:{};stroke-width:{}\" "
        "/>";
const std::string opr_info =
        "mge:type=\"opr\" mge:id=\"{}\" mge:size=\"{}\" mge:name=\"{}\"";
const std::string chunk_info =
        "mge:type=\"chunk\" mge:id=\"{}\" mge:time=\"{}\" mge:addr=\"{}\" "
        "mge:size=\"{}\" mge:owner_var_name=\"{}\"";
const std::string animate =
        "<animate attributeName=\"opacity\" from=\"0\" to=\"1\" "
        "begin=\"{}.mouseover\" fill=\"freeze\" dur=\"1s\"/>\n<animate "
        "attributeName=\"opacity\" from=\"1\" to=\"0\" begin=\"{}.mouseout\" "
        "fill=\"freeze\" dur=\"1s\"/>";

//! escape the characters that are special inside XML text / attribute values
std::string escape_xml(const std::string& raw) {
    std::string escaped;
    escaped.reserve(raw.size());
    for (char c : raw) {
        switch (c) {
            case '&':
                escaped += "&amp;";
                break;
            case '<':
                escaped += "&lt;";
                break;
            case '>':
                escaped += "&gt;";
                break;
            case '"':
                escaped += "&quot;";
                break;
            case '\'':
                escaped += "&apos;";
                break;
            default:
                escaped += c;
        }
    }
    return escaped;
}

//! recursion terminator: no parameter left to substitute
std::string& replace_by_parameter(std::string& original_str, size_t /*index*/) {
    return original_str;
}

/*!
 * \brief substitute the "{}" placeholders of \p original_str in place
 *
 * Starting the search at \p index, the next placeholder is replaced by
 * \p parameter and the search continues behind the inserted text, so a
 * parameter that itself contains "{}" is never expanded again. Surplus
 * parameters are ignored once the placeholders run out.
 *
 * \return reference to \p original_str
 */
template <typename... Args>
std::string& replace_by_parameter(std::string& original_str, size_t index,
                                  const std::string& parameter,
                                  const Args&... args) {
    index = original_str.find("{}", index);
    if (index == std::string::npos) {
        // std::string::replace would throw std::out_of_range on npos
        return original_str;
    }
    original_str.replace(index, 2, parameter);
    return replace_by_parameter(original_str, index + parameter.length(),
                                args...);
}

//! build the mge:* attribute list of an operator; \p name is XML-escaped
std::string set_opr_info(const std::string& id, const std::string& size,
                         const std::string& name,
                         std::string info = opr_info) {
    return replace_by_parameter(info, 0, id, size, escape_xml(name));
}

//! build the mge:* attribute list of a memory chunk; \p owner_var_name is
//! XML-escaped
std::string set_chunk_info(const std::string& id, const std::string& time,
                           const std::string& addr, const std::string& size,
                           const std::string& owner_var_name,
                           std::string info = chunk_info) {
    return replace_by_parameter(info, 0, id, time, addr, size,
                                escape_xml(owner_var_name));
}

//! build a <rect> element; \p info is a ready-made attribute list that is
//! inserted verbatim
std::string draw_rect(const std::string& x, const std::string& y,
                      const std::string& widith, const std::string& height,
                      const std::string& color, const std::string& info,
                      std::string r = rect) {
    return replace_by_parameter(r, 0, x, y, widith, height, color, info);
}

//! build a <text> element showing \p txt
std::string draw_text(const std::string& x, const std::string& y,
                      const std::string& font_size, const std::string& txt,
                      std::string t = text) {
    return replace_by_parameter(t, 0, x, y, font_size, txt);
}

//! build a <polyline> element through the points listed in \p point_seq
std::string draw_polyline(const std::string& point_seq,
                          const std::string& color, const std::string& width,
                          std::string p = polyline) {
    return replace_by_parameter(p, 0, point_seq, color, width);
}
}  // namespace
| void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
| float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | |||||
| opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | |||||
| float address_scale = 1; | |||||
| size_t opr_nr = m_opr_seq_recorder.size(); | |||||
| if (opr_nr * OPR_RECT_WIDTH > SVG_WIDTH) { | |||||
| svg_width = SVG_WIDTH; | |||||
| opr_rect_width = svg_width / opr_nr; | |||||
| opr_rect_height = opr_rect_width / 2; | |||||
| } else { | |||||
| opr_rect_width = OPR_RECT_WIDTH; | |||||
| svg_width = opr_nr * opr_rect_width; | |||||
| } | |||||
| if (m_sum_mem_size > SVG_HEIGHT) { | |||||
| svg_height = SVG_HEIGHT; | |||||
| address_scale = svg_height / m_sum_mem_size; | |||||
| } else { | |||||
| svg_height = m_sum_mem_size; | |||||
| } | |||||
| // Rescale | |||||
| float aspect_ratio = SVG_WIDTH / SVG_HEIGHT; | |||||
| if (svg_width / svg_height < 1) { | |||||
| svg_width = svg_height * aspect_ratio; | |||||
| opr_rect_width = svg_width / opr_nr; | |||||
| opr_rect_height = opr_rect_width / 2; | |||||
| } else if (svg_width / svg_height > aspect_ratio) { | |||||
| svg_height = svg_width / aspect_ratio; | |||||
| address_scale = svg_height / m_sum_mem_size; | |||||
| } | |||||
| svg_height = svg_height + opr_rect_height * 2; | |||||
| std::ofstream outfile; | |||||
| outfile.open(svg_name); | |||||
| outfile << "<?xml version=\"1.0\" standalone=\"no\"?>" << std::endl; | |||||
| outfile << "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN/\" " | |||||
| "\"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">" | |||||
| << std::endl; | |||||
| outfile << "<svg width=\"" + std::to_string(svg_width) + "\" height=\"" + | |||||
| std::to_string(svg_height) + | |||||
| "\" version=\"1.1\" " | |||||
| "xmlns=\"http://www.w3.org/2000/svg\">" | |||||
| << std::endl; | |||||
| float base_height = svg_height - opr_rect_height; | |||||
| std::string peak_mem_polyline = | |||||
| "0," + | |||||
| std::to_string(base_height - m_peak_mem_size * address_scale) + | |||||
| " " + std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + | |||||
| "," + std::to_string(base_height - m_peak_mem_size * address_scale); | |||||
| std::string sum_mem_polyline = | |||||
| "0," + | |||||
| std::to_string(base_height - m_sum_mem_size * address_scale) + " " + | |||||
| std::to_string(m_opr_seq_recorder.size() * opr_rect_width) + "," + | |||||
| std::to_string(base_height - m_sum_mem_size * address_scale); | |||||
| std::string memory_polyline = ""; | |||||
| for (size_t i = 0; i < m_opr_seq_recorder.size(); i++) { | |||||
| auto&& opr = m_opr_seq_recorder.at(i); | |||||
| memory_polyline += | |||||
| std::to_string((i + 0.5) * opr_rect_width) + "," + | |||||
| std::to_string(base_height - opr.size * address_scale) + " "; | |||||
| outfile << draw_text(std::to_string(i * opr_rect_width), | |||||
| std::to_string(svg_height - opr_rect_height * 0.5), | |||||
| std::to_string(opr_rect_height * 0.5), | |||||
| "opr" + std::to_string(i)) | |||||
| << std::endl; | |||||
| std::string opr_info = | |||||
| set_opr_info( | |||||
| std::to_string(opr.id), | |||||
| std::to_string(opr.size) + "B(" + | |||||
| std::to_string(opr.size / 1024.0 / 1024.0) + | |||||
| "MiB)", | |||||
| opr.name) + | |||||
| " opacity=\"0\""; | |||||
| outfile << draw_rect(std::to_string(i * opr_rect_width), | |||||
| std::to_string(base_height), | |||||
| std::to_string(opr_rect_width), | |||||
| std::to_string(opr_rect_height), "white", opr_info) | |||||
| << std::endl; | |||||
| } | |||||
| for (size_t i = 0; i < m_memory_chunk_recorder.size(); i++) { | |||||
| auto&& chunk = m_memory_chunk_recorder.at(i); | |||||
| std::string chunk_info = set_chunk_info( | |||||
| std::to_string(chunk.id), | |||||
| "[" + std::to_string(chunk.time_begin) + "," + | |||||
| std::to_string(chunk.time_end) + ")", | |||||
| "[" + std::to_string(chunk.addr_begin) + "," + | |||||
| std::to_string(chunk.addr_end) + ")", | |||||
| std::to_string(chunk.addr_end - chunk.addr_begin) + "B(" + | |||||
| std::to_string((chunk.addr_end - chunk.addr_begin) / | |||||
| 1024.0 / 1024.0) + | |||||
| "MiB)", | |||||
| chunk.owner_var_name); | |||||
| outfile << draw_rect( | |||||
| std::to_string(chunk.time_begin * opr_rect_width), | |||||
| std::to_string(base_height - | |||||
| chunk.addr_end * address_scale), | |||||
| std::to_string((chunk.time_end - chunk.time_begin) * | |||||
| opr_rect_width), | |||||
| std::to_string((chunk.addr_end - chunk.addr_begin) * | |||||
| address_scale), | |||||
| "gray", chunk_info) | |||||
| << std::endl; | |||||
| outfile << draw_text(std::to_string(chunk.time_begin * opr_rect_width), | |||||
| std::to_string(base_height - | |||||
| chunk.addr_end * address_scale + 9), | |||||
| std::to_string(9), | |||||
| "chunk" + std::to_string(chunk.id)) | |||||
| << std::endl; | |||||
| } | |||||
| outfile << draw_text("0", | |||||
| std::to_string(base_height - | |||||
| m_peak_mem_size * address_scale + | |||||
| opr_rect_height * 0.5), | |||||
| std::to_string(opr_rect_height * 0.5), | |||||
| "peak_memory_size:" + std::to_string(m_peak_mem_size) + | |||||
| "B(" + | |||||
| std::to_string(m_peak_mem_size / 1024.0 / | |||||
| 1024.0) + | |||||
| "MiB)") | |||||
| << std::endl; | |||||
| outfile << draw_text("0", | |||||
| std::to_string(base_height - | |||||
| m_sum_mem_size * address_scale + | |||||
| opr_rect_height * 0.5), | |||||
| std::to_string(opr_rect_height * 0.5), | |||||
| "sum_memory_size:" + std::to_string(m_sum_mem_size) + | |||||
| "B(" + | |||||
| std::to_string(m_sum_mem_size / 1024.0 / | |||||
| 1024.0) + | |||||
| "MiB)") | |||||
| << std::endl; | |||||
| outfile << draw_polyline(memory_polyline, "blue", | |||||
| std::to_string(opr_rect_height * 0.1)) | |||||
| << std::endl; | |||||
| outfile << draw_polyline(peak_mem_polyline, "green", | |||||
| std::to_string(opr_rect_height * 0.1)) | |||||
| << std::endl; | |||||
| outfile << draw_polyline(sum_mem_polyline, "red", | |||||
| std::to_string(opr_rect_height * 0.1)) | |||||
| << std::endl; | |||||
| outfile << "<text svg:info=\"The abscissa represents the opr sequence, the " | |||||
| "ordinate represents the logical address.\" " | |||||
| "svg:chunk_time=\"[opra,oprb) means the chunk is created when " | |||||
| "opra execute and is freed before oprb\" " | |||||
| "svg:chunk_oner_var_name=\"var that first creates this " | |||||
| "chunk\"></text>" | |||||
| << std::endl; | |||||
| outfile << "</svg>" << std::endl; | |||||
| outfile.close(); | |||||
| } | |||||
| void StaticMemRecorder::show(std::string svg_name) { | |||||
| for (auto&& i : m_memory_chunk_recorder) { | |||||
| if (i.id >= m_weight_chunk_id) { | |||||
| break; | |||||
| } | |||||
| size_t begin = i.time_begin, end = i.time_end; | |||||
| if (i.is_overwrite) { | |||||
| begin++; | |||||
| } | |||||
| for (size_t j = begin; j < end; j++) { | |||||
| m_opr_seq_recorder.at(j).size += i.size_orig; | |||||
| } | |||||
| } | |||||
| // log peak memory size, where it is reached and which chunks constitute it. | |||||
| mgb_log("peak_mem_size = %zu\n", m_peak_mem_size); | |||||
| size_t max_size = 0; | |||||
| std::vector<size_t> opr_ids; | |||||
| for (auto&& i : m_opr_seq_recorder) { | |||||
| if (i.size == max_size) { | |||||
| opr_ids.push_back(i.id); | |||||
| } else if (i.size > max_size) { | |||||
| max_size = i.size; | |||||
| opr_ids.clear(); | |||||
| opr_ids.push_back(i.id); | |||||
| } | |||||
| } | |||||
| auto opr2chunk = get_chunk_construct(opr_ids); | |||||
| mgb_log("oprs reach the peak memory:\n"); | |||||
| for (auto&& i : opr_ids) { | |||||
| mgb_log("opr id = %zu\n", i); | |||||
| } | |||||
| mgb_log("More details:\n"); | |||||
| for (size_t i = 0; i < opr2chunk.size(); i++) { | |||||
| mgb_log("opr id = %zu\n", opr_ids.at(i)); | |||||
| if (i + 1 < opr2chunk.size() && | |||||
| opr2chunk.at(i) == opr2chunk.at(i + 1)) { | |||||
| continue; | |||||
| } | |||||
| for (size_t j = 0; j < opr2chunk.at(i).size(); j++) { | |||||
| auto&& chunk = m_memory_chunk_recorder.at(opr2chunk.at(i).at(j)); | |||||
| mgb_log("[memory_chunk_id=%zu, size=%zu B, " | |||||
| "[life_begin=%zu,life_end=%zu), owner_opr_name=%s]\n", | |||||
| chunk.id, chunk.size_orig, chunk.time_begin, chunk.time_end, | |||||
| m_opr_seq_recorder.at(chunk.time_begin).name.c_str()); | |||||
| } | |||||
| } | |||||
| dump_svg(svg_name); | |||||
| } | |||||
| std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | |||||
| std::vector<size_t> opr_ids) { | |||||
| std::vector<std::vector<size_t>> chunk_ids; | |||||
| chunk_ids.resize(opr_ids.size()); | |||||
| for (auto&& i : m_memory_chunk_recorder) { | |||||
| if (i.id >= m_weight_chunk_id) { | |||||
| break; | |||||
| } | |||||
| size_t begin = i.time_begin, end = i.time_end; | |||||
| if (i.is_overwrite) { | |||||
| begin = begin + 1; | |||||
| } | |||||
| if (opr_ids.front() >= end || opr_ids.back() < begin) { | |||||
| continue; | |||||
| } | |||||
| for (size_t k = 0; k < opr_ids.size(); k++) { | |||||
| if (opr_ids.at(k) >= end) { | |||||
| break; | |||||
| } else if (opr_ids.at(k) >= begin) { | |||||
| chunk_ids.at(k).push_back(i.id); | |||||
| } | |||||
| } | |||||
| } | |||||
| return chunk_ids; | |||||
| } | |||||
| @@ -0,0 +1,85 @@ | |||||
| /** | |||||
| * \file src/plugin/include/megbrain/plugin/static_mem_record.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/utils/metahelper.h" | |||||
| namespace mgb { | |||||
| namespace cg { | |||||
| class StaticMemRecorder : public NonCopyableObj { | |||||
| public: | |||||
| static StaticMemRecorder& Instance() { | |||||
| static StaticMemRecorder StaticMemRecorder; | |||||
| return StaticMemRecorder; | |||||
| } | |||||
| struct opr_record { | |||||
| size_t id, size; | |||||
| std::string name; | |||||
| }; | |||||
| struct memory_chunk_record { | |||||
| size_t id, size_orig, time_begin, time_end, addr_begin, | |||||
| addr_end, overwrite_dest_id; | |||||
| bool is_overwrite; | |||||
| std::string owner_var_name; | |||||
| }; | |||||
| void active() { m_is_record = true; } | |||||
| bool valid() { return m_is_record; } | |||||
| void clear_opr_seq() { m_opr_seq_recorder.clear(); } | |||||
| void regist_opr_seq(opr_record opr) { m_opr_seq_recorder.push_back(opr); } | |||||
| void clear_memory_chunk() { m_memory_chunk_recorder.clear(); } | |||||
| void regist_memory_chunk(memory_chunk_record mcr) { | |||||
| m_memory_chunk_recorder.push_back(mcr); | |||||
| } | |||||
| void regist_memory_chunk_owner_var_name(size_t id, std::string name) { | |||||
| m_memory_chunk_recorder.at(id).owner_var_name = name; | |||||
| } | |||||
| void regist_peak_mem_size(size_t size) { m_peak_mem_size = size; } | |||||
| const size_t& peak_mem_size() { return m_peak_mem_size; } | |||||
| void set_sum_mem_size(size_t size) { m_sum_mem_size = size; } | |||||
| const size_t& sum_mem_size() { return m_sum_mem_size; } | |||||
| const size_t& set_weight_chunk_id() { | |||||
| m_weight_chunk_id = m_memory_chunk_recorder.size(); | |||||
| return m_weight_chunk_id; | |||||
| } | |||||
| const size_t& weight_chunk_id() { return m_weight_chunk_id; } | |||||
| void dump_svg(std::string svg_name); | |||||
| void show(std::string svg_name); | |||||
| private: | |||||
| bool m_is_record = false; | |||||
| // All chunks after m_memory_chunk_recorder.at(m_weight_chunk_id) are | |||||
| // weights memory chunks | |||||
| size_t m_peak_mem_size, m_sum_mem_size, m_weight_chunk_id; | |||||
| std::vector<opr_record> m_opr_seq_recorder; | |||||
| std::vector<memory_chunk_record> m_memory_chunk_recorder; | |||||
| std::vector<std::vector<size_t>> get_chunk_construct( | |||||
| std::vector<size_t> opr_ids); | |||||
| }; | |||||
| } // namespace cg | |||||
| } // namespace mgb | |||||