You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task13_14.py 1.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  # Task 13-14 self-check: compare uctc.nn transpose/argmax results against NumPy.
import sys

import numpy as np
import uctc.nn as nn


def is_close(x, y):
    """Return True when x and y differ by less than 1e-5 (absolute tolerance)."""
    return abs(x - y) < 1e-5


def _fail(message):
    """Print *message* in bold red and stop the script with exit status 1."""
    print(f"\033[1;31m{message}\033[0m")
    # Exit non-zero so a shell / CI runner can tell a failed check from a pass;
    # exiting with status 0 here would report success even on failure.
    sys.exit(1)


def _compare(expected, actual, task, expected_name, actual_name):
    """Compare two flat sequences element-wise and call _fail on the first mismatch.

    expected: reference values computed with NumPy (a list).
    actual: values read back from a uctc.nn tensor via .data(); any iterable.
    task: task number shown in the error message.
    expected_name, actual_name: labels shown in the error message.
    """
    actual = list(actual)
    # Check lengths first: a shorter result would otherwise raise IndexError,
    # and a longer one would have its extra elements silently ignored.
    if len(expected) != len(actual):
        _fail(
            f"Task {task} Error: {expected_name} has {len(expected)} elements "
            f"but {actual_name} has {len(actual)}"
        )
    for i, (x, y) in enumerate(zip(expected, actual)):
        if not is_close(x, y):
            _fail(f"Task {task} Error: {expected_name} data[{i}] != {actual_name} data[{i}]")


tensor1 = np.random.rand(42, 48)
tensor2 = nn.pyarray_to_tensor(tensor1)

# Task 13: transpose. NumPy's flatten() lists the transposed matrix in row-major
# order; .data() is assumed to use the same order (TODO confirm against uctc.nn).
t_1data = tensor1.transpose().flatten().tolist()
t_2data = tensor2.transpose().data()
_compare(t_1data, t_2data, 13, "t1", "t2")

# Task 14: argmax along axis 0 ...
at1 = np.argmax(tensor1, 0).flatten().tolist()
at2 = nn.argmax(tensor2, 0).data()
_compare(at1, at2, 14, "at1", "at2")

# ... and along axis 1. Compare at3/at4 here (not at1/at2 again) so the
# axis-1 result is actually verified.
at3 = np.argmax(tensor1, 1).flatten().tolist()
at4 = nn.argmax(tensor2, 1).data()
_compare(at3, at4, 14, "at3", "at4")

print(f"\033[1;32m[PASSED] Task 13-14 finished!\033[0m")

计算机大作业

Contributors (1)