From: @linqingke Reviewed-by: @jjfeing, @xu-yfei Signed-off-by: @xu-yfei (tags/v1.2.0-rc1)
| @@ -22,7 +22,7 @@ from te.platform.cce_conf import te_set_version | |||
| from te.platform.fusion_manager import set_current_op_name | |||
| from te.platform.fusion_util import fusion_op, dump_fusion_json | |||
| from te.platform.parallel_compilation import init_multi_process_env, get_finished_compilation_task, \ | |||
| deinit_multi_process_env, dispatch_autotune_task, start_ga_multi_process | |||
| deinit_multi_process_env, dispatch_autotune_task, start_ga_multi_process, import_py_module | |||
| import auto_tune | |||
| from schedule_search.rl_online_tune import rl_tune_init, dispatch_fusion_tune_task, dispatch_single_tune_task, \ | |||
| rl_tune_deinit | |||
| @@ -48,6 +48,8 @@ class TbeTuner: | |||
| if os.environ.get("TUNE_DUMP_PATH") is not None: | |||
| self.offline_dump_path = os.getenv("TUNE_DUMP_PATH", "") | |||
| self._creating_custom_path(tune_mode) | |||
| self.fusion_need_sync = 0 | |||
| self.module_list = {} | |||
| def init_tune_interface(self, json_str, process_num): | |||
| """ | |||
| @@ -222,6 +224,24 @@ class TbeTuner: | |||
| log.info("GA Tune init success.") | |||
| return True | |||
| def sync_fusion_env(self): | |||
| """ | |||
| Sync fusion env | |||
| :return: None | |||
| """ | |||
| if self.fusion_need_sync == 0: | |||
| return | |||
| module_using = [] | |||
| for key, value in self.module_list.items(): | |||
| if value > 0: | |||
| module_using.append(str(key)) | |||
| self.module_list[key] = 0 | |||
| module_str = ",".join(module_using) | |||
| import_py_module(module_str) | |||
| self.fusion_need_sync = 0 | |||
| def rl_tune(self, task_id, op_json): | |||
| """ | |||
| RL tune for single op and fusion op | |||
| @@ -231,6 +251,7 @@ class TbeTuner: | |||
| """ | |||
| json_info = json.loads(op_json) | |||
| if "fusion_op" in json_info: | |||
| self.sync_fusion_env() | |||
| ret = self.fusion_rl_tune(task_id, json_info) | |||
| else: | |||
| ret = self.single_rl_tune(task_id, json_info) | |||
| @@ -244,6 +265,7 @@ class TbeTuner: | |||
| """ | |||
| json_info = json.loads(op_json) | |||
| if "fusion_op" in json_info: | |||
| self.sync_fusion_env() | |||
| self.fusion_ga_tune(task_id, json_info) | |||
| else: | |||
| self.single_ga_tune(task_id, json_info) | |||
| @@ -289,6 +311,9 @@ class TbeTuner: | |||
| l1size = 0 # todo need to verify | |||
| ret = dispatch_single_tune_task(graph_id, task_id, l1size, base_kernel, kernel_name, op_module_name, | |||
| op_module_name + "@" + op_module_name, op_type, op_type, op_args) | |||
| self.module_list[op_module_name] = 1 | |||
| self.fusion_need_sync += 1 | |||
| return ret, job_type | |||
| def get_op_module_names(self, json_info): | |||
| @@ -20,6 +20,7 @@ | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| @@ -28,6 +29,8 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| const std::map<std::string, size_t> kFormatIndexMap = {{"NCHW", 2}, {"HWCN", 0}, {"NHWC", 1}}; | |||
| template <typename T> | |||
| class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| public: | |||
| @@ -339,7 +342,16 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||
| void SetStrideAndDilation(const CNodePtr &kernel_node) { | |||
| std::vector<int64_t> stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "stride"); | |||
| std::vector<int64_t> dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "dilation"); | |||
| (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_), | |||
| std::string format_me = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "format"); | |||
| auto iter = kFormatIndexMap.find(format_me); | |||
| if (iter == kFormatIndexMap.end()) { | |||
| MS_LOG(EXCEPTION) << "OriFormat is " << format_me << ", Please confirm that in {NCHW, HWCN, NHWC}."; | |||
| } | |||
| size_t h_index = iter->second; | |||
| if (stride_me.size() < h_index + 2) { | |||
| MS_LOG(EXCEPTION) << "Strides should greater than " << h_index + 1 << ", but got " << stride_me.size(); | |||
| } | |||
| (void)std::transform(stride_me.begin() + h_index, stride_me.begin() + h_index + 2, std::back_inserter(stride_), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| (void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| @@ -1981,7 +1981,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) | |||
| self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) | |||
| self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) | |||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) | |||
| self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) | |||
| self.add_prim_attr('stride', self.stride) | |||
| self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) | |||
| self.add_prim_attr('dilation', self.dilation) | |||