|
|
|
@@ -0,0 +1,49 @@ |
|
|
|
import torch |
|
|
|
import moxing as mox |
|
|
|
import os |
|
|
|
import argparse |
|
|
|
import torch_npu # 确保安装了 torch_npu |
|
|
|
|
|
|
|
# 创建参数解析器 |
|
|
|
parser = argparse.ArgumentParser(description='Training script with output path option') |
|
|
|
parser.add_argument('--input', default=None, type=str, help='input, where the dataset is stored.') |
|
|
|
parser.add_argument('--output', type=str, default='', help='Output path for saving the result') |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
# 检查 NPU 是否可用 |
|
|
|
try: |
|
|
|
import torch_npu |
|
|
|
|
|
|
|
if torch_npu.npu.is_available(): |
|
|
|
device = torch.device('npu') |
|
|
|
print("Using NPU for training.") |
|
|
|
else: |
|
|
|
device = torch.device('cpu') |
|
|
|
print("NPU is not available, using CPU for training.") |
|
|
|
except ImportError: |
|
|
|
device = torch.device('cpu') |
|
|
|
print("torch_npu module not found, using CPU for training.") |
|
|
|
|
|
|
|
# 模拟一个简单的训练过程 |
|
|
|
# 这里其实不涉及真正的模型训练,只是作为示例流程展示 |
|
|
|
# 输出 Hello World |
|
|
|
result = "Hello World" |
|
|
|
print(result) |
|
|
|
|
|
|
|
# 将结果保存到本地文件 |
|
|
|
local_result_path = 'hellworld.pth' |
|
|
|
with open(local_result_path, 'w') as f: |
|
|
|
f.write(result) |
|
|
|
|
|
|
|
# 根据参数确定 OBS 路径 |
|
|
|
obs_result_path = args.output if args.output else 'obs://nudt-cloudream2/cds/trainResult/model/hellworld.pth' |
|
|
|
# 如果传入的路径不是以斜杠结尾,则添加斜杠 |
|
|
|
if obs_result_path and not obs_result_path.endswith('/'): |
|
|
|
obs_result_path += '/' |
|
|
|
obs_result_path += 'hellworld.pth' |
|
|
|
|
|
|
|
try: |
|
|
|
mox.file.copy(local_result_path, obs_result_path) |
|
|
|
print(f"Result has been successfully saved to {obs_result_path}") |
|
|
|
except Exception as e: |
|
|
|
print(f"Failed to save result to OBS: {e}") |