Browse Source

npu的数据集和算法上传

master
zzysz@qq.com 9 months ago
parent
commit
7d0e9f8237
2 changed files with 49 additions and 0 deletions
  1. +49
    -0
      npu/helloworld.py
  2. BIN
      npu/input.zip

+ 49
- 0
npu/helloworld.py View File

@@ -0,0 +1,49 @@
import torch
import moxing as mox
import os
import argparse
import torch_npu # 确保安装了 torch_npu

# 创建参数解析器
parser = argparse.ArgumentParser(description='Training script with output path option')
parser.add_argument('--input', default=None, type=str, help='input, where the dataset is stored.')
parser.add_argument('--output', type=str, default='', help='Output path for saving the result')
args = parser.parse_args()

# 检查 NPU 是否可用
try:
import torch_npu

if torch_npu.npu.is_available():
device = torch.device('npu')
print("Using NPU for training.")
else:
device = torch.device('cpu')
print("NPU is not available, using CPU for training.")
except ImportError:
device = torch.device('cpu')
print("torch_npu module not found, using CPU for training.")

# 模拟一个简单的训练过程
# 这里其实不涉及真正的模型训练,只是作为示例流程展示
# 输出 Hello World
result = "Hello World"
print(result)

# 将结果保存到本地文件
local_result_path = 'hellworld.pth'
with open(local_result_path, 'w') as f:
f.write(result)

# 根据参数确定 OBS 路径
obs_result_path = args.output if args.output else 'obs://nudt-cloudream2/cds/trainResult/model/hellworld.pth'
# 如果传入的路径不是以斜杠结尾,则添加斜杠
if obs_result_path and not obs_result_path.endswith('/'):
obs_result_path += '/'
obs_result_path += 'hellworld.pth'

try:
mox.file.copy(local_result_path, obs_result_path)
print(f"Result has been successfully saved to {obs_result_path}")
except Exception as e:
print(f"Failed to save result to OBS: {e}")

BIN
npu/input.zip View File


Loading…
Cancel
Save