You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """train"""
import argparse
import os

import numpy as np
import mindspore.context as context

from src.read_var import read_nc
from src.GOMO import GOMO_init, GOMO, read_init
  21. parser = argparse.ArgumentParser(description='GOMO')
  22. parser.add_argument('--file_path', type=str, default=None, help='file path')
  23. parser.add_argument('--outputs_path', type=str, default=None, help='outputs path')
  24. parser.add_argument('--im', type=int, default=65, help='im size')
  25. parser.add_argument('--jm', type=int, default=49, help='jm size')
  26. parser.add_argument('--kb', type=int, default=21, help='kb size')
  27. parser.add_argument('--stencil_width', type=int, default=1, help='stencil width')
  28. parser.add_argument('--step', type=int, default=10, help='time step')
  29. args_gomo = parser.parse_args()
  30. context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False, enable_graph_kernel=True)
  31. if __name__ == "__main__":
  32. variable = read_nc(args_gomo.file_path)
  33. im = args_gomo.im
  34. jm = args_gomo.jm
  35. kb = args_gomo.kb
  36. stencil_width = args_gomo.stencil_width
  37. # variable init
  38. dx, dy, dz, uab, vab, elb, etb, sb, tb, ub, vb, dt, h, w, wubot, wvbot, vfluxb, utb, vtb, dhb, egb, vfluxf, z, zz, \
  39. dzz, cor, fsm = read_init(
  40. variable, im, jm, kb)
  41. # define grid and init variable update
  42. net_init = GOMO_init(im, jm, kb, stencil_width)
  43. init_res = net_init(dx, dy, dz, uab, vab, elb, etb, sb, tb, ub, vb, h, w, vfluxf, zz, fsm)
  44. for res_tensor in init_res:
  45. if isinstance(res_tensor, (list, tuple)):
  46. for rt in res_tensor:
  47. rt.data_sync(True)
  48. else:
  49. res_tensor.data_sync(True)
  50. ua, va, el, et, etf, d, dt, l, q2b, q2lb, kh, km, kq, aam, w, q2, q2l, t, s, u, v, cbc, rmean, rho, x_d, y_d, z_d\
  51. = init_res
  52. # define GOMO model
  53. Model = GOMO(im=im, jm=jm, kb=kb, stencil_width=stencil_width, variable=variable, x_d=x_d, y_d=y_d, z_d=z_d,
  54. q2b=q2b, q2lb=q2lb, aam=aam, cbc=cbc, rmean=rmean)
  55. # time step of GOMO Model
  56. for step in range(1, args_gomo.step+1):
  57. elf, etf, ua, uab, va, vab, el, elb, d, u, v, w, kq, km, kh, q2, q2l, tb, t, sb, s, rho, wubot, wvbot, ub, vb, \
  58. egb, etb, dt, dhb, utb, vtb, vfluxb, et, steps, vamax, q2b, q2lb = Model(
  59. etf, ua, uab, va, vab, el, elb, d, u, v, w, kq, km, kh, q2, q2l, tb, t, sb, s, rho,
  60. wubot, wvbot, ub, vb, egb, etb, dt, dhb, utb, vtb, vfluxb, et)
  61. vars_list = etf, ua, uab, va, vab, el, elb, d, u, v, w, kq, km, kh, q2, q2l, tb, t, sb, s, rho, wubot, wvbot, \
  62. ub, vb, egb, etb, dt, dhb, utb, vtb, vfluxb, et
  63. for var in vars_list:
  64. var.asnumpy()
  65. # save output
  66. if step % 5 == 0:
  67. np.save(args_gomo.outputs_path + "u_"+str(step)+".npy", u.asnumpy())
  68. np.save(args_gomo.outputs_path + "v_" + str(step) + ".npy", v.asnumpy())
  69. np.save(args_gomo.outputs_path + "t_" + str(step) + ".npy", t.asnumpy())
  70. np.save(args_gomo.outputs_path + "et_" + str(step) + ".npy", et.asnumpy())