You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

perception.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import numpy as np
  2. import time
  3. import os
  4. import matplotlib.pyplot as plt
  5. import uctc.nn as nn
  6. from utils import parameter_data, Dataset
  7. use_graphics = False
  8. class PerceptronModel(object):
  9. def __init__(self, dimensions):
  10. """
  11. Initialize a new Perceptron instance.
  12. A perceptron classifies data points as either belonging to a particular
  13. class (+1) or not (-1). `dimensions` is the dimensionality of the data.
  14. For example, dimensions=2 would mean that the perceptron must classify
  15. 2D points.
  16. """
  17. self.w = nn.Parameter(parameter_data(dimensions, 1))
  18. def get_weights(self):
  19. """
  20. Return a Parameter instance with the current weights of the perceptron.
  21. """
  22. return self.w.data()
  23. def run(self, x):
  24. """
  25. Calculates the score assigned by the perceptron to a data point x.
  26. Inputs:
  27. x: a node with shape (1 x dimensions)
  28. Returns: a node containing a single number (the score)
  29. """
  30. "*** YOUR CODE HERE ***"
  31. out = nn.Linear(x, self.w)
  32. return out
  33. def get_prediction(self, x):
  34. """
  35. Calculates the predicted class for a single data point `x`.
  36. Returns: 1 or -1
  37. """
  38. "*** YOUR CODE HERE ***"
  39. score = self.run(x).data()[0]
  40. # score = np.array(x.data()).dot(np.array(self.w.data()))
  41. if score >= 0:
  42. return 1
  43. else:
  44. return -1
  45. def train(self, dataset):
  46. """
  47. Train the perceptron until convergence.
  48. """
  49. "*** YOUR CODE HERE ***"
  50. batch_size = 1
  51. while True:
  52. converged = True
  53. for x, y in dataset.iterate_once(batch_size):
  54. prediction = self.get_prediction(x)
  55. x = np.array(x.data(), dtype=np.float32)
  56. y = int(y.data()[0])
  57. # assert 0
  58. if prediction != y:
  59. # print(prediction, y)
  60. converged = False
  61. self.w.update(nn.pyarray_to_tensor(x), -y)
  62. # time.sleep(0.01)
  63. if converged:
  64. break
  65. class PerceptronDataset(Dataset):
  66. def __init__(self, model: PerceptronModel):
  67. points = 500
  68. x = np.hstack([np.random.randn(points, 2), np.ones((points, 1))])
  69. y = np.where(x[:, 0] + 2 * x[:, 1] - 1 >= 0, 1.0, -1.0)
  70. super().__init__(x, np.expand_dims(y, axis=1))
  71. self.model = model
  72. self.epoch = 0
  73. limits = np.array([-3.0, 3.0])
  74. if use_graphics:
  75. fig, ax = plt.subplots(1, 1)
  76. ax.set_xlim(limits)
  77. ax.set_ylim(limits)
  78. positive = ax.scatter(*x[y == 1, :-1].T, color="red", marker="+")
  79. negative = ax.scatter(*x[y == -1, :-1].T, color="blue", marker="_")
  80. line, = ax.plot([], [], color="black")
  81. text = ax.text(0.03, 0.97, "", transform=ax.transAxes, va="top")
  82. ax.legend([positive, negative], [1, -1])
  83. plt.show(block=False)
  84. self.fig = fig
  85. self.line = line
  86. self.text = text
  87. self.limits = limits
  88. self.last_update = time.time()
  89. def iterate_once(self, batch_size):
  90. self.epoch += 1
  91. for i, (x, y) in enumerate(super().iterate_once(batch_size)):
  92. yield x, y
  93. if time.time() - self.last_update > 0.001:
  94. w = self.model.get_weights()
  95. limits = self.limits
  96. print(f"epoch: {self.epoch}\npoint: {i * batch_size + 1}/{len(self.x)}\nweights: {w}")
  97. if use_graphics:
  98. if w[1] != 0:
  99. self.line.set_data(limits, (-w[0] * limits - w[2]) / w[1])
  100. elif w[0] != 0:
  101. self.line.set_data(np.full(2, -w[2] / w[0]), limits)
  102. else:
  103. self.line.set_data([], [])
  104. self.text.set_text(
  105. f"epoch: {self.epoch}\npoint: {i * batch_size + 1}/{len(self.x)}\nweights: {w}")
  106. self.fig.canvas.draw_idle()
  107. self.fig.canvas.start_event_loop(1e-3)
  108. self.last_update = time.time()
  109. model = PerceptronModel(3)
  110. dataset = PerceptronDataset(model)
  111. model.train(dataset)

计算机大作业