You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

RNN.cs 21 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578
  1. using OneOf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Reflection;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.ArgsDefinition.Rnn;
  7. using Tensorflow.Keras.Engine;
  8. using Tensorflow.Keras.Saving;
  9. using Tensorflow.Util;
  10. using Tensorflow.Common.Extensions;
  11. using System.Linq.Expressions;
  12. using Tensorflow.Keras.Utils;
  13. using Tensorflow.Common.Types;
  14. using System.Runtime.CompilerServices;
  15. // from tensorflow.python.distribute import distribution_strategy_context as ds_context;
  16. namespace Tensorflow.Keras.Layers.Rnn
  17. {
/// <summary>
/// Base class for recurrent layers.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// </summary>
public class RNN : RnnBase
{
    // Constructor-time configuration (return_sequences, time_major, stateful, ...).
    private RNNArgs _args;
    private object _input_spec = null; // or NoneValue??
    private object _state_spec = null;
    // Lazily materialized cell states; see the States property.
    private Tensors _states = null;
    private object _constants_spec = null;
    // Number of constant tensors appended to the combined inputs/states lists.
    private int _num_constants;
    protected IVariableV1 _kernel;
    protected IVariableV1 _bias;
    private IRnnCell _cell;
    // The wrapped recurrent cell. Init-only: set once at construction and
    // registered as a tracked trackable so its variables are tracked with
    // this layer.
    protected IRnnCell Cell
    {
        get
        {
            return _cell;
        }
        init
        {
            _cell = value;
            _self_tracked_trackables.Add(_cell);
        }
    }
  46. public RNN(RNNArgs args) : base(PreConstruct(args))
  47. {
  48. _args = args;
  49. SupportsMasking = true;
  50. // if is StackedRnncell
  51. if (args.Cells != null)
  52. {
  53. Cell = new StackedRNNCells(new StackedRNNCellsArgs
  54. {
  55. Cells = args.Cells
  56. });
  57. }
  58. else
  59. {
  60. Cell = args.Cell;
  61. }
  62. // get input_shape
  63. _args = PreConstruct(args);
  64. _num_constants = 0;
  65. }
    // States is a tuple consisting of the cell state sizes, like
    // (cell1.state_size, cell2.state_size, ...).
    // state_size can be a single integer, can also be a list/tuple of integers,
    // can also be TensorShape or a list/tuple of TensorShape.
    public Tensors States
    {
        get
        {
            if (_states == null)
            {
                // No states recorded yet: build a placeholder structure with
                // one null tensor per entry of the cell's state-size nest.
                // CHECK(Rinne): check if this is correct.
                var nested = Cell.StateSize.MapStructure<Tensor?>(x => null);
                _states = nested.AsNest().ToTensors();
            }
            return _states;
        }
        set { _states = value; }
    }
    /// <summary>
    /// Computes the layer's output shape structure from the input shape,
    /// honoring time_major, return_sequences and return_state.
    /// </summary>
    /// <param name="input_shape">Full input shape; dims 0/1 are interpreted as
    /// (batch, time) or (time, batch) depending on TimeMajor.</param>
    private INestStructure<Shape> compute_output_shape(Shape input_shape)
    {
        var batch = input_shape[0];
        var time_step = input_shape[1];
        if (_args.TimeMajor)
        {
            // Time-major layout: the first two dims are (time, batch).
            (batch, time_step) = (time_step, batch);
        }
        // state_size is an array of ints or a positive integer
        var state_size = Cell.StateSize;
        if(state_size?.TotalNestedCount == 1)
        {
            // Normalize a single-element structure to a one-element list.
            state_size = new NestList<long>(state_size.Flatten().First());
        }
        // Maps a flat per-cell output size to a full output shape.
        Func<long, Shape> _get_output_shape = (flat_output_size) =>
        {
            var output_dim = new Shape(flat_output_size).as_int_list();
            Shape output_shape;
            if (_args.ReturnSequences)
            {
                if (_args.TimeMajor)
                {
                    output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
                }
                else
                {
                    output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));
                }
            }
            else
            {
                // Only the last output is returned: the time dim is dropped.
                output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
            }
            return output_shape;
        };
        // Cells exposing an "output_size" property advertise their own output
        // structure; otherwise fall back to the first state size.
        Type type = Cell.GetType();
        PropertyInfo output_size_info = type.GetProperty("output_size");
        INestStructure<Shape> output_shape;
        if (output_size_info != null)
        {
            output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize);
        }
        else
        {
            // NOTE(review): the earlier `state_size?.` suggests state_size can
            // be null, in which case this dereference throws — confirm.
            output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First()));
        }
        if (_args.ReturnState)
        {
            // Append one (batch, state_dim...) shape per state entry.
            Func<long, Shape> _get_state_shape = (flat_state) =>
            {
                var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list());
                return new Shape(state_shape);
            };
            var state_shape = Nest.MapStructure(_get_state_shape, state_size);
            return new Nest<Shape>(new[] { output_shape, state_shape } );
        }
        else
        {
            return output_shape;
        }
    }
    /// <summary>
    /// Computes the output mask(s): the timestep mask is propagated only when
    /// full sequences are returned, and returned states never carry a mask.
    /// </summary>
    /// <param name="inputs">Layer inputs (unused beyond the standard signature).</param>
    /// <param name="mask">Input mask(s); only the first flattened mask is used.</param>
    private Tensors compute_mask(Tensors inputs, Tensors mask)
    {
        // Time step masks must be the same for each input.
        // This is because the mask for an RNN is of size [batch, time_steps, 1],
        // and specifies which time steps should be skipped, and a time step
        // must be skipped for all inputs.
        mask = nest.flatten(mask)[0];
        var output_mask = _args.ReturnSequences ? mask : null;
        if (_args.ReturnState)
        {
            // Pad with one null mask per returned state tensor.
            var state_mask = new List<Tensor>();
            for (int i = 0; i < len(States); i++)
            {
                state_mask.Add(null);
            }
            return new List<Tensor> { output_mask }.concat(state_mask);
        }
        else
        {
            return output_mask;
        }
    }
    /// <summary>
    /// Builds the wrapped cell against the first input shape and marks both
    /// the cell and this layer as built.
    /// </summary>
    public override void build(KerasShapesWrapper input_shape)
    {
        // Only the first input shape participates in building.
        input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);
        // NOTE(review): the three local helpers below are currently never
        // called; they mirror the input/state spec computation of the Keras
        // reference implementation and look like work-in-progress.
        InputSpec get_input_spec(Shape shape)
        {
            var input_spec_shape = shape.as_int_list();
            var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
            if (!_args.Stateful)
            {
                // Non-stateful layers accept any batch size.
                input_spec_shape[batch_index] = -1;
            }
            input_spec_shape[time_step_index] = -1;
            return new InputSpec(shape: input_spec_shape);
        }
        Shape get_step_input_shape(Shape shape)
        {
            // return shape[1:] if self.time_major else (shape[0],) + shape[2:]
            if (_args.TimeMajor)
            {
                return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
            }
            else
            {
                return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
            }
        }
        object get_state_spec(Shape shape)
        {
            var state_spec_shape = shape.as_int_list();
            // prepend a wildcard (-1) batch dim
            state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
            return new InputSpec(shape: state_spec_shape);
        }
        // Check whether the input shape contains any nested shapes. It could be
        // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
        // numpy inputs.
        if (Cell is Layer layer && !layer.Built)
        {
            layer.build(input_shape);
            layer.Built = true;
        }
        this.built = true;
    }
    /// <summary>
    /// Runs the wrapped cell over the time dimension of the inputs via
    /// keras.backend.rnn, then assembles the configured outputs.
    /// </summary>
    /// <param name="inputs">Input tensor(s), [batch, timesteps, ...] (or
    /// [timesteps, batch, ...] when TimeMajor).</param>
    /// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell</param>
    /// <param name="training"></param>
    /// <param name="optional_args">Optional <c>RnnOptionalArgs</c> carrying the
    /// mask (binary tensor of shape [batch_size, timesteps] indicating whether
    /// a given timestep should be masked) and the constants (tensors passed to
    /// the cell at each timestep).</param>
    /// <returns>The output sequence or the last output, optionally followed by
    /// the final states when ReturnState is set.</returns>
    /// <exception cref="ValueError"></exception>
    /// <exception cref="NotImplementedException"></exception>
    protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
    {
        RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
        if(optional_args is not null && rnn_optional_args is null)
        {
            // NOTE(review): "shhould" typo in the message below — left as-is
            // because it is a runtime string.
            throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`");
        }
        Tensors? constants = rnn_optional_args?.Constants;
        Tensors? mask = rnn_optional_args?.Mask;
        //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
        // Ragged tensors are not accepted for now.
        int row_length = 0; // TODO(Rinne): support this param.
        bool is_ragged_input = false;
        _validate_args_if_ragged(is_ragged_input, mask);
        // Split a combined inputs list into (inputs, initial_state, constants)
        // and resolve the initial state.
        (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);
        // Clear cached dropout masks so each forward pass resamples them.
        _maybe_reset_cell_dropout_mask(Cell);
        if (Cell is StackedRNNCells)
        {
            var stack_cell = Cell as StackedRNNCells;
            foreach (IRnnCell cell in stack_cell.Cells)
            {
                _maybe_reset_cell_dropout_mask(cell);
            }
        }
        if (mask != null)
        {
            // Time step masks must be the same for each input.
            mask = mask.Flatten().First();
        }
        Shape input_shape;
        if (!inputs.IsNested())
        {
            // In the case of nested input, use the first element for shape check
            // input_shape = nest.flatten(inputs)[0].shape;
            // TODO(Wanglongzhi2001)
            input_shape = inputs.Flatten().First().shape;
        }
        else
        {
            input_shape = inputs.shape;
        }
        var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
        if (_args.Unroll && timesteps == null)
        {
            throw new ValueError(
            "Cannot unroll a RNN if the " +
            "time dimension is undefined. \n" +
            "- If using a Sequential model, " +
            "specify the time dimension by passing " +
            "an `input_shape` or `batch_input_shape` " +
            "argument to your first layer. If your " +
            "first layer is an Embedding, you can " +
            "also use the `input_length` argument.\n" +
            "- If using the functional API, specify " +
            "the time dimension by passing a `shape` " +
            "or `batch_shape` argument to your Input layer."
            );
        }
        // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
        Func<Tensors, Tensors, (Tensors, Tensors)> step;
        bool is_tf_rnn_cell = false;
        if (constants is not null)
        {
            if (!Cell.SupportOptionalArgs)
            {
                throw new ValueError(
                    $"RNN cell {Cell} does not support constants." +
                    $"Received: constants={constants}");
            }
            // With constants, the trailing `_num_constants` entries of the
            // states list are the constants and must be split off per step.
            step = (inputs, states) =>
            {
                constants = new Tensors(states.TakeLast(_num_constants).ToArray());
                states = new Tensors(states.SkipLast(_num_constants).ToArray());
                states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
                var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
                return (output, new_states.Single);
            };
        }
        else
        {
            step = (inputs, states) =>
            {
                states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
                var (output, new_states) = Cell.Apply(inputs, states);
                return (output, new_states);
            };
        }
        // NOTE(review): `row_length` is a non-nullable int, so the
        // `row_length != null` test below is always true and the
        // `new Tensor(timesteps)` branch is unreachable; input_length is
        // always Tensor(0). Likely `row_length` should be `int?` — confirm.
        var (last_output, outputs, states) = keras.backend.rnn(
            step,
            inputs,
            initial_state,
            constants: constants,
            go_backwards: _args.GoBackwards,
            mask: mask,
            unroll: _args.Unroll,
            input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
            time_major: _args.TimeMajor,
            zero_output_for_mask: _args.ZeroOutputForMask,
            return_all_outputs: _args.ReturnSequences);
        if (_args.Stateful)
        {
            throw new NotImplementedException("this argument havn't been developed.");
        }
        Tensors output = new Tensors();
        if (_args.ReturnSequences)
        {
            // TODO(Rinne): add go_backwards parameter and revise the `row_length` param
            output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
        }
        else
        {
            output = last_output;
        }
        if (_args.ReturnState)
        {
            // Append the final states after the output tensor(s).
            foreach (var state in states)
            {
                output.Add(state);
            }
            return output;
        }
        else
        {
            return output;
        }
    }
  345. public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
  346. {
  347. RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
  348. if (optional_args is not null && rnn_optional_args is null)
  349. {
  350. throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
  351. }
  352. Tensors? constants = rnn_optional_args?.Constants;
  353. (inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);
  354. if(initial_states is null && constants is null)
  355. {
  356. return base.Apply(inputs);
  357. }
  358. // TODO(Rinne): implement it.
  359. throw new NotImplementedException();
  360. }
  361. private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
  362. {
  363. if (inputs.Length > 1)
  364. {
  365. if (_num_constants != 0)
  366. {
  367. initial_state = new Tensors(inputs.Skip(1).ToArray());
  368. }
  369. else
  370. {
  371. initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray());
  372. constants = new Tensors(inputs.TakeLast(_num_constants).ToArray());
  373. }
  374. if (len(initial_state) == 0)
  375. initial_state = null;
  376. inputs = inputs[0];
  377. }
  378. if (_args.Stateful)
  379. {
  380. if (initial_state != null)
  381. {
  382. var tmp = new Tensor[] { };
  383. foreach (var s in nest.flatten(States))
  384. {
  385. tmp.add(tf.math.count_nonzero(s.Single()));
  386. }
  387. var non_zero_count = tf.add_n(tmp);
  388. initial_state = tf.cond(non_zero_count > 0, States, initial_state);
  389. if ((int)non_zero_count.numpy() > 0)
  390. {
  391. initial_state = States;
  392. }
  393. }
  394. else
  395. {
  396. initial_state = States;
  397. }
  398. //initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state);
  399. }
  400. else if (initial_state is null)
  401. {
  402. initial_state = get_initial_state(inputs);
  403. }
  404. if (initial_state.Length != States.Length)
  405. {
  406. throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
  407. $"but it received {initial_state.Length} " +
  408. $"initial state(s). Input received: {inputs}");
  409. }
  410. return (inputs, initial_state, constants);
  411. }
  412. private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
  413. {
  414. if (!is_ragged_input)
  415. {
  416. return;
  417. }
  418. if (_args.Unroll)
  419. {
  420. throw new ValueError("The input received contains RaggedTensors and does " +
  421. "not support unrolling. Disable unrolling by passing " +
  422. "`unroll=False` in the RNN Layer constructor.");
  423. }
  424. if (mask != null)
  425. {
  426. throw new ValueError($"The mask that was passed in was {mask}, which " +
  427. "cannot be applied to RaggedTensor inputs. Please " +
  428. "make sure that there is no mask injected by upstream " +
  429. "layers.");
  430. }
  431. }
  432. void _maybe_reset_cell_dropout_mask(ILayer cell)
  433. {
  434. if (cell is DropoutRNNCellMixin CellDRCMixin)
  435. {
  436. CellDRCMixin.reset_dropout_mask();
  437. CellDRCMixin.reset_recurrent_dropout_mask();
  438. }
  439. }
  440. private static RNNArgs PreConstruct(RNNArgs args)
  441. {
  442. if (args.Kwargs == null)
  443. {
  444. args.Kwargs = new Dictionary<string, object>();
  445. }
  446. // If true, the output for masked timestep will be zeros, whereas in the
  447. // false case, output from previous timestep is returned for masked timestep.
  448. var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
  449. Shape input_shape;
  450. var propIS = (Shape)args.Kwargs.Get("input_shape", null);
  451. var propID = (int?)args.Kwargs.Get("input_dim", null);
  452. var propIL = (int?)args.Kwargs.Get("input_length", null);
  453. if (propIS == null && (propID != null || propIL != null))
  454. {
  455. input_shape = new Shape(
  456. propIL ?? -1,
  457. propID ?? -1);
  458. args.Kwargs["input_shape"] = input_shape;
  459. }
  460. return args;
  461. }
  462. public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
  463. {
  464. throw new NotImplementedException();
  465. }
// It seems the cell cannot be passed as an interface type, so these factory overloads remain commented out.
  467. //public RNN New(IRnnArgCell cell,
  468. // bool return_sequences = false,
  469. // bool return_state = false,
  470. // bool go_backwards = false,
  471. // bool stateful = false,
  472. // bool unroll = false,
  473. // bool time_major = false)
  474. // => new RNN(new RNNArgs
  475. // {
  476. // Cell = cell,
  477. // ReturnSequences = return_sequences,
  478. // ReturnState = return_state,
  479. // GoBackwards = go_backwards,
  480. // Stateful = stateful,
  481. // Unroll = unroll,
  482. // TimeMajor = time_major
  483. // });
  484. //public RNN New(List<IRnnArgCell> cell,
  485. // bool return_sequences = false,
  486. // bool return_state = false,
  487. // bool go_backwards = false,
  488. // bool stateful = false,
  489. // bool unroll = false,
  490. // bool time_major = false)
  491. // => new RNN(new RNNArgs
  492. // {
  493. // Cell = cell,
  494. // ReturnSequences = return_sequences,
  495. // ReturnState = return_state,
  496. // GoBackwards = go_backwards,
  497. // Stateful = stateful,
  498. // Unroll = unroll,
  499. // TimeMajor = time_major
  500. // });
    /// <summary>
    /// Builds the cell's default initial state from the batch size and dtype
    /// of the inputs.
    /// </summary>
    protected Tensors get_initial_state(Tensors inputs)
    {
        var input = inputs[0];
        // NOTE(review): the shape is taken from the whole `inputs` structure
        // rather than the single `input` extracted above — confirm intended.
        var input_shape = array_ops.shape(inputs);
        // Batch dim is second in time-major layout, first otherwise.
        var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
        var dtype = input.dtype;
        Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);
        return init_state;
    }
  510. }
  511. }