diff --git a/npu/helloworld.py b/npu/helloworld.py new file mode 100644 index 0000000..410d849 --- /dev/null +++ b/npu/helloworld.py @@ -0,0 +1,49 @@ +import torch +import moxing as mox +import os +import argparse +import torch_npu # 确保安装了 torch_npu + +# 创建参数解析器 +parser = argparse.ArgumentParser(description='Training script with output path option') +parser.add_argument('--input', default=None, type=str, help='input, where the dataset is stored.') +parser.add_argument('--output', type=str, default='', help='Output path for saving the result') +args = parser.parse_args() + +# 检查 NPU 是否可用 +try: + import torch_npu + + if torch_npu.npu.is_available(): + device = torch.device('npu') + print("Using NPU for training.") + else: + device = torch.device('cpu') + print("NPU is not available, using CPU for training.") +except ImportError: + device = torch.device('cpu') + print("torch_npu module not found, using CPU for training.") + +# 模拟一个简单的训练过程 +# 这里其实不涉及真正的模型训练,只是作为示例流程展示 +# 输出 Hello World +result = "Hello World" +print(result) + +# 将结果保存到本地文件 +local_result_path = 'hellworld.pth' +with open(local_result_path, 'w') as f: + f.write(result) + +# 根据参数确定 OBS 路径 +obs_result_path = args.output if args.output else 'obs://nudt-cloudream2/cds/trainResult/model/hellworld.pth' +# 如果传入的路径不是以斜杠结尾,则添加斜杠 +if obs_result_path and not obs_result_path.endswith('/'): + obs_result_path += '/' +obs_result_path += 'hellworld.pth' + +try: + mox.file.copy(local_result_path, obs_result_path) + print(f"Result has been successfully saved to {obs_result_path}") +except Exception as e: + print(f"Failed to save result to OBS: {e}") diff --git a/npu/input.zip b/npu/input.zip new file mode 100644 index 0000000..16ed551 Binary files /dev/null and b/npu/input.zip differ