|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- /*****************************************************************************
- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ******************************************************************************/
-
- using System;
- using System.Collections.Generic;
- using System.Text;
- using Tensorflow.Keras.Engine;
- using Tensorflow.Operations.Activation;
-
- namespace Tensorflow
- {
/// <summary>
/// A recurrent cell built on <see cref="LayerRNNCell"/> that is configured with a
/// unit count and an activation function.
/// </summary>
/// <remarks>
/// This type only stores its configuration; the constructor restricts inputs to
/// rank 2 through <c>input_spec</c>.
/// </remarks>
public class BasicRNNCell : LayerRNNCell
{
    // Number of units supplied at construction; also reported as the state size.
    private readonly int _num_units;

    // Activation callback; never null once the constructor has run.
    private readonly Func<Tensor, string, Tensor> _activation;

    /// <summary>
    /// Size of the cell state, which equals the number of units given to the constructor.
    /// </summary>
    protected override int state_size
    {
        get { return _num_units; }
    }

    /// <summary>
    /// Creates the cell and records its configuration.
    /// </summary>
    /// <param name="num_units">Number of units in the cell; exposed through <c>state_size</c>.</param>
    /// <param name="activation">
    /// Activation function to store. When <c>null</c>, <c>math_ops.tanh</c> is used instead.
    /// </param>
    /// <param name="reuse">Forwarded unchanged to the base constructor as <c>_reuse</c>.</param>
    /// <param name="name">Forwarded unchanged to the base constructor.</param>
    /// <param name="dtype">Forwarded unchanged to the base constructor.</param>
    public BasicRNNCell(
        int num_units,
        Func<Tensor, string, Tensor> activation = null,
        bool? reuse = null,
        string name = null,
        TF_DataType dtype = TF_DataType.DtInvalid)
        : base(_reuse: reuse, name: name, dtype: dtype)
    {
        // The cell only accepts rank-2 inputs.
        input_spec = new InputSpec(ndim: 2);

        _num_units = num_units;

        // Fall back to tanh when the caller does not supply an activation.
        _activation = activation ?? math_ops.tanh;
    }
}
- }
|