| @@ -141,9 +141,13 @@ R"__usage__( | |||||
| level 2 the computing graph can be destructed to reduce memory usage. Read | level 2 the computing graph can be destructed to reduce memory usage. Read | ||||
| the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more | the doc of `ComputingGraph::Options::comp_node_seq_record_level` for more | ||||
| details. | details. | ||||
| )__usage__" | |||||
| #ifndef __IN_TEE_ENV__ | |||||
| R"__usage__( | |||||
| --get-static-mem-info <svgname> | --get-static-mem-info <svgname> | ||||
| Record the static graph's static memory info. | Record the static graph's static memory info. | ||||
| )__usage__" | )__usage__" | ||||
| #endif | |||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| R"__usage__( | R"__usage__( | ||||
| --full-run | --full-run | ||||
| @@ -538,7 +542,9 @@ struct Args { | |||||
| #endif | #endif | ||||
| bool reproducible = false; | bool reproducible = false; | ||||
| std::string fast_run_cache_path; | std::string fast_run_cache_path; | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| std::string static_mem_svg_path; | std::string static_mem_svg_path; | ||||
| #endif | |||||
| bool copy_to_host = false; | bool copy_to_host = false; | ||||
| int nr_run = 10; | int nr_run = 10; | ||||
| int nr_warmup = 1; | int nr_warmup = 1; | ||||
| @@ -797,9 +803,11 @@ void run_test_st(Args &env) { | |||||
| } | } | ||||
| auto func = env.load_ret.graph_compile(out_spec); | auto func = env.load_ret.graph_compile(out_spec); | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| if (!env.static_mem_svg_path.empty()) { | if (!env.static_mem_svg_path.empty()) { | ||||
| func->get_static_memory_alloc_info(env.static_mem_svg_path); | func->get_static_memory_alloc_info(env.static_mem_svg_path); | ||||
| } | } | ||||
| #endif | |||||
| auto warmup = [&]() { | auto warmup = [&]() { | ||||
| printf("=== prepare: %.3fms; going to warmup\n", | printf("=== prepare: %.3fms; going to warmup\n", | ||||
| timer.get_msecs_reset()); | timer.get_msecs_reset()); | ||||
| @@ -1383,6 +1391,7 @@ Args Args::from_argv(int argc, char **argv) { | |||||
| graph_opt.comp_node_seq_record_level = 2; | graph_opt.comp_node_seq_record_level = 2; | ||||
| continue; | continue; | ||||
| } | } | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| if (!strcmp(argv[i], "--get-static-mem-info")) { | if (!strcmp(argv[i], "--get-static-mem-info")) { | ||||
| ++i; | ++i; | ||||
| mgb_assert(i < argc, "value not given for --get-static-mem-info"); | mgb_assert(i < argc, "value not given for --get-static-mem-info"); | ||||
| @@ -1393,6 +1402,7 @@ Args Args::from_argv(int argc, char **argv) { | |||||
| ret.static_mem_svg_path.c_str()); | ret.static_mem_svg_path.c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| #endif | |||||
| #if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
| if (!strcmp(argv[i], "--fast-run")) { | if (!strcmp(argv[i], "--fast-run")) { | ||||
| ret.use_fast_run = true; | ret.use_fast_run = true; | ||||
| @@ -491,7 +491,7 @@ AsyncExecutable& ComputingGraphImpl::ComputingSequence::execute() { | |||||
| do_execute(nullptr); | do_execute(nullptr); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | ||||
| const std::string& svg_name) { | const std::string& svg_name) { | ||||
| check_not_finalized(); | check_not_finalized(); | ||||
| @@ -523,7 +523,7 @@ void ComputingGraphImpl::ComputingSequence::get_static_memory_alloc_info( | |||||
| "svg_name must be end with \".svg\"\n"); | "svg_name must be end with \".svg\"\n"); | ||||
| recorder.show(svg_name); | recorder.show(svg_name); | ||||
| } | } | ||||
| #endif | |||||
| AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | AsyncExecutable& ComputingGraphImpl::ComputingSequence::wait() { | ||||
| do_wait(true); | do_wait(true); | ||||
| return *this; | return *this; | ||||
| @@ -170,9 +170,10 @@ public: | |||||
| } | } | ||||
| std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | std::unique_ptr<RecordedComputingSequence> as_recorded_seq(); | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| void get_static_memory_alloc_info( | void get_static_memory_alloc_info( | ||||
| const std::string& svg_name = "static_mem_record.svg") override; | const std::string& svg_name = "static_mem_record.svg") override; | ||||
| #endif | |||||
| }; | }; | ||||
| class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | class ComputingGraphImpl::MegDNNDtorCheck : public NonCopyableObj { | ||||
| @@ -178,18 +178,19 @@ bool SeqMemOptimizer::run_static_mem_alloc() { | |||||
| ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ThinHashMap<MemAllocPlan::Chunk*, MemChunkLifeInterval> chk2interval; | ||||
| // get all memory chunks | // get all memory chunks | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| if (StaticMemRecorder::Instance().valid()) { | if (StaticMemRecorder::Instance().valid()) { | ||||
| StaticMemRecorder::Instance().clear_opr_seq(); | StaticMemRecorder::Instance().clear_opr_seq(); | ||||
| } | } | ||||
| #endif | |||||
| for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | for (size_t idx = 0; idx < m_cur_seq_full->size(); ++ idx) { | ||||
| OperatorNodeBase *opr = m_cur_seq_full->at(idx); | OperatorNodeBase *opr = m_cur_seq_full->at(idx); | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| if (StaticMemRecorder::Instance().valid()) { | if (StaticMemRecorder::Instance().valid()) { | ||||
| StaticMemRecorder::Instance().regist_opr_seq( | StaticMemRecorder::Instance().regist_opr_seq( | ||||
| {idx, 0, opr->name()}); | {idx, 0, opr->name()}); | ||||
| } | } | ||||
| #endif | |||||
| auto &&dep_map = opr->node_prop().dep_map(); | auto &&dep_map = opr->node_prop().dep_map(); | ||||
| if (in_sys_alloc(opr)) { | if (in_sys_alloc(opr)) { | ||||
| @@ -358,6 +359,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( | |||||
| chk.chunk->mem_alloc_status.set_static_offset( | chk.chunk->mem_alloc_status.set_static_offset( | ||||
| allocator->get_start_addr(&chk)); | allocator->get_start_addr(&chk)); | ||||
| } | } | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| auto& recorder = StaticMemRecorder::Instance(); | auto& recorder = StaticMemRecorder::Instance(); | ||||
| if (recorder.valid()) { | if (recorder.valid()) { | ||||
| for (size_t i = 0; i < chunks.size(); i++) { | for (size_t i = 0; i < chunks.size(); i++) { | ||||
| @@ -366,6 +368,7 @@ bool SeqMemOptimizer::run_static_mem_alloc_on_comp_node( | |||||
| } | } | ||||
| recorder.regist_peak_mem_size(size); | recorder.regist_peak_mem_size(size); | ||||
| } | } | ||||
| #endif | |||||
| } | } | ||||
| return should_realloc; | return should_realloc; | ||||
| @@ -119,7 +119,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { | |||||
| do_solve(); | do_solve(); | ||||
| check_result_and_calc_lower_bound(); | check_result_and_calc_lower_bound(); | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| if (StaticMemRecorder::Instance().valid()) { | if (StaticMemRecorder::Instance().valid()) { | ||||
| StaticMemRecorder::Instance().clear_memory_chunk(); | StaticMemRecorder::Instance().clear_memory_chunk(); | ||||
| for (auto&& i : m_interval) { | for (auto&& i : m_interval) { | ||||
| @@ -135,7 +135,7 @@ StaticMemAlloc& StaticMemAllocImplHelper::solve() { | |||||
| is_overwrite, ""}); | is_overwrite, ""}); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -194,11 +194,12 @@ class AsyncExecutable : public json::Serializable, | |||||
| m_user_data.get_user_data<OutputVarsUserData>(); | m_user_data.get_user_data<OutputVarsUserData>(); | ||||
| return (*(output_vars_pair.first))->get_output_vars(); | return (*(output_vars_pair.first))->get_output_vars(); | ||||
| } | } | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| virtual void get_static_memory_alloc_info(const std::string& svg_name) { | virtual void get_static_memory_alloc_info(const std::string& svg_name) { | ||||
| mgb_assert(svg_name.length() < 0, | mgb_assert(svg_name.length() < 0, | ||||
| "can't call this function directly\n"); | "can't call this function directly\n"); | ||||
| } | } | ||||
| #endif | |||||
| }; | }; | ||||
| @@ -14,13 +14,11 @@ | |||||
| #ifndef __IN_TEE_ENV__ | #ifndef __IN_TEE_ENV__ | ||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | #include <iostream> | ||||
| #endif | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace cg; | using namespace cg; | ||||
| namespace { | namespace { | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| #define SVG_WIDTH 20000.0 | #define SVG_WIDTH 20000.0 | ||||
| #define SVG_HEIGHT 15000.0 | #define SVG_HEIGHT 15000.0 | ||||
| #define OPR_RECT_WIDTH 40.0 | #define OPR_RECT_WIDTH 40.0 | ||||
| @@ -86,13 +84,9 @@ std::string draw_polyline(std::string point_seq, std::string color, | |||||
| std::string width, std::string p = polyline) { | std::string width, std::string p = polyline) { | ||||
| return replace_by_parameter(p, 0, point_seq, color, width); | return replace_by_parameter(p, 0, point_seq, color, width); | ||||
| } | } | ||||
| #endif | |||||
| } // namespace | } // namespace | ||||
| void StaticMemRecorder::dump_svg(std::string svg_name) { | void StaticMemRecorder::dump_svg(std::string svg_name) { | ||||
| #ifdef __IN_TEE_ENV__ | |||||
| MGB_MARK_USED_VAR(svg_name); | |||||
| #else | |||||
| float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | float svg_width = SVG_WIDTH, svg_height = SVG_HEIGHT, | ||||
| opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | opr_rect_width = OPR_RECT_WIDTH, opr_rect_height = OPR_RECT_HEIGHT; | ||||
| float address_scale = 1; | float address_scale = 1; | ||||
| @@ -247,7 +241,6 @@ void StaticMemRecorder::dump_svg(std::string svg_name) { | |||||
| << std::endl; | << std::endl; | ||||
| outfile << "</svg>" << std::endl; | outfile << "</svg>" << std::endl; | ||||
| outfile.close(); | outfile.close(); | ||||
| #endif | |||||
| } | } | ||||
| void StaticMemRecorder::show(std::string svg_name) { | void StaticMemRecorder::show(std::string svg_name) { | ||||
| @@ -326,3 +319,4 @@ std::vector<std::vector<size_t>> StaticMemRecorder::get_chunk_construct( | |||||
| } | } | ||||
| return chunk_ids; | return chunk_ids; | ||||
| } | } | ||||
| #endif | |||||
| @@ -12,7 +12,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain/utils/metahelper.h" | #include "megbrain/utils/metahelper.h" | ||||
| #ifndef __IN_TEE_ENV__ | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace cg { | namespace cg { | ||||
| @@ -83,3 +83,4 @@ private: | |||||
| }; | }; | ||||
| } // namespace cg | } // namespace cg | ||||
| } // namespace mgb | } // namespace mgb | ||||
| #endif | |||||