You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import numpy as np
  2. import uctc.nn as nn
# Seed NumPy's global random number generator once, at import time, so that
# parameter_data() (which draws from np.random.uniform) returns the same
# values on every run for the same sequence of calls.
np.random.seed(42)
  4. def parameter_data(*shape):
  5. assert len(shape) == 2, (
  6. "Shape must have 2 dimensions, instead has {}".format(len(shape)))
  7. assert all(isinstance(dim, int) and dim > 0 for dim in shape), (
  8. "Shape must consist of positive integers, got {!r}".format(shape))
  9. limit = np.sqrt(3.0 / np.mean(shape))
  10. data = np.random.uniform(low=-limit, high=limit, size=shape).astype(np.float32)
  11. return data
  12. class Dataset(object):
  13. def __init__(self, x, y):
  14. assert isinstance(x, np.ndarray)
  15. assert isinstance(y, np.ndarray)
  16. assert np.issubdtype(x.dtype, np.floating)
  17. assert np.issubdtype(y.dtype, np.floating)
  18. assert x.ndim == 2
  19. assert y.ndim == 2
  20. assert x.shape[0] == y.shape[0]
  21. self.x = x
  22. self.y = y
  23. def iterate_once(self, batch_size):
  24. assert isinstance(batch_size, int) and batch_size > 0, (
  25. f"Batch size should be a positive integer, got {batch_size}")
  26. assert self.x.shape[0] % batch_size == 0, (
  27. f"Dataset size {self.x.shape[0]} is not divisible by batch size {batch_size}")
  28. index = 0
  29. while index < self.x.shape[0]:
  30. x = self.x[index:index + batch_size]
  31. y = self.y[index:index + batch_size]
  32. yield nn.Constant(x), nn.Constant(y)
  33. index += batch_size
  34. def iterate_forever(self, batch_size):
  35. while True:
  36. yield from self.iterate_once(batch_size)
  37. def get_validation_accuracy(self):
  38. raise NotImplementedError(
  39. "No validation data is available for this dataset. "
  40. "In this assignment, only the Digit Classification and Language "
  41. "Identification datasets have validation data.")