You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test.py 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  class Tensor:
      """A value that records how it was computed so gradients can flow back.

      Attributes:
          data: The wrapped value.
          depend: List of ``(parent, grad_fn)`` pairs. ``grad_fn`` converts a
              gradient arriving at this tensor into the parent's share of it.
          grad: Gradient accumulated by ``backward``; starts at 0.
      """

      def __init__(self, data, depend=None):
          """Wrap ``data``; ``depend`` lists the parents this value came from."""
          self.data = data
          # A fresh list per instance: a shared ``[]`` default would be the
          # same object for every tensor created without parents.
          self.depend = [] if depend is None else depend
          self.grad = 0

      @staticmethod
      def _as_tensor(value):
          """Return ``value`` unchanged if it is a Tensor, else wrap it."""
          return value if isinstance(value, Tensor) else Tensor(value)

      def __mul__(self, data):
          """Return ``self * data``; ``data`` may be a Tensor or a plain value."""
          other = self._as_tensor(data)

          def grad_for_self(grad):
              # d(self * other) / d(self) = other
              return grad * other.data

          def grad_for_other(grad):
              # d(self * other) / d(other) = self
              return grad * self.data

          parents = [(self, grad_for_self), (other, grad_for_other)]
          return Tensor(self.data * other.data, parents)

      def __rmul__(self, data):
          """Return ``data * self`` for a left operand that is not a Tensor."""
          # Same computation and parent order as ``__mul__``.
          return self.__mul__(data)

      def __add__(self, data):
          """Return ``self + data``; ``data`` may be a Tensor or a plain value."""
          other = self._as_tensor(data)

          def passthrough(grad):
              # Addition hands the incoming gradient to both operands as is.
              return grad

          parents = [(self, passthrough), (other, passthrough)]
          return Tensor(self.data + other.data, parents)

      def __radd__(self, data):
          """Return ``data + self`` for a left operand that is not a Tensor."""
          return self.__add__(data)

      def __repr__(self):
          return f"Tensor:{self.data}"

      def backward(self, grad=None):
          """Propagate gradients recursively from this tensor to its parents.

          Args:
              grad: Gradient arriving from one consumer of this tensor. When
                  omitted, this tensor is treated as the root of the pass and
                  its own gradient is set to 1.

          Note:
              Only the root's ``grad`` is reset. Gradients already stored on
              other tensors are kept, so a second pass over the same graph
              adds to them.
          """
          if grad is None:
              grad = 1
              self.grad = grad
          else:
              # A tensor used in several places receives one call per use;
              # summing them gives its total gradient.
              self.grad += grad
          # Forward only the newly arrived contribution. Forwarding the running
          # total would re-send earlier contributions and double-count them.
          for parent, grad_fn in self.depend:
              parent.backward(grad_fn(grad))
  # Demo: build two independent ``x * x`` nodes, combine them with ``+``,
  # then backpropagate from the result and show the stored gradients.
  x = Tensor(4)
  f, g = x * x, x * x
  y = f + g
  y.backward()
  print(x)
  print(y, g.grad, x.grad)

Edge : 一个开源的科学计算引擎