|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- using System;
- using System.Linq;
- using Tensorflow.Common.Types;
- using Tensorflow.Framework;
- using Tensorflow.Keras.ArgsDefinition;
- using Tensorflow.Keras.Engine;
- using Tensorflow.Keras.Utils;
- using static Tensorflow.Binding;
-
- namespace Tensorflow.Keras.Layers
- {
/// <summary>
/// Keras <c>Flatten</c> layer: reshapes each input to rank 2 while keeping the
/// leading (batch) axis, i.e. (batch, d1, ..., dn) becomes (batch, d1 * ... * dn),
/// and a rank-1 input becomes (batch, 1).
/// </summary>
public class Flatten : Layer
{
    FlattenArgs args;
    // NOTE(review): assigned in the constructor but never read in this class; confirm
    // whether the base Layer exposes its own input-spec member that should receive it.
    InputSpec input_spec;
    // True when args.DataFormat normalises to "channels_first".
    bool _channels_first;

    /// <summary>
    /// Creates the layer. <c>args.DataFormat</c> is normalised in place via
    /// <c>conv_utils.normalize_data_format</c>.
    /// </summary>
    public Flatten(FlattenArgs args)
        : base(args)
    {
        this.args = args;
        args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
        input_spec = new InputSpec(min_ndim: 1);
        _channels_first = args.DataFormat == "channels_first";
    }

    /// <summary>
    /// Flattens <paramref name="inputs"/> to (batch, -1). For channels_first data the
    /// channel axis (1) is first moved to the end so that the flattened feature order
    /// matches the channels_last layout.
    /// </summary>
    protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
    {
        Tensor x = inputs;
        if (_channels_first)
        {
            x = ToChannelsLast(x);
        }

        if (tf.executing_eagerly())
        {
            // NOTE(review): relies on eager tensors always having a fully known static shape.
            return array_ops.reshape(x, new[] { x.shape[0], -1 });
        }

        return FlattenGraph(x);
    }

    /// <summary>
    /// Transposes axis 1 to the last position (e.g. perm [0, 2, 3, 1] for rank 4).
    /// Tensors whose static rank is &lt;= 1 or unknown are returned unchanged.
    /// </summary>
    static Tensor ToChannelsLast(Tensor x)
    {
        var rank = x.shape.ndim;
        if (rank <= 1)
        {
            return x;
        }

        // perm[0] stays 0 (batch axis); axes 2..rank-1 shift left; axis 1 goes last.
        var perm = new int[rank];
        for (var i = 1; i < rank - 1; i++)
        {
            perm[i] = i + 1;
        }
        perm[rank - 1] = 1;
        return array_ops.transpose(x, perm);
    }

    /// <summary>
    /// Graph-mode flatten that keeps as much static shape information as possible:
    /// a known batch size is used directly; otherwise a fully known feature size is
    /// used; otherwise the batch size is read from the runtime shape.
    /// </summary>
    static Tensor FlattenGraph(Tensor x)
    {
        var input_shape = x.shape;
        var rank = input_shape.ndim;
        if (rank == 1)
        {
            return array_ops.expand_dims(x, axis: 1);
        }

        if (rank > 1)
        {
            var batch_dim = tensor_shape.dimension_value(input_shape[0]);
            if (batch_dim != -1)
            {
                return array_ops.reshape(x, new[] { batch_dim, -1 });
            }

            // Unknown dimensions are reported as -1, so they must never be multiplied
            // into a static feature size (that yields a negative or wrong width).
            var non_batch_dims = input_shape.dims.Skip(1).ToArray();
            if (non_batch_dims.All(d => d >= 0))
            {
                var feature_size = non_batch_dims.Aggregate(1L, (acc, d) => acc * d);
                return array_ops.reshape(x, new[] { -1L, feature_size });
            }
        }

        // Batch size and feature size are both statically unknown (or the rank itself is):
        // take the batch size from the runtime shape and let reshape infer the rest.
        var dynamic_shape = tf.stack(new[] { tf.shape(x)[0], tf.constant(-1) });
        return array_ops.reshape(x, dynamic_shape);
    }
}
- }
|