Browse Source

rename call to call_fn.

tags/v0.30
Oceania2018 5 years ago
parent
commit
e92aa44c1d
23 changed files with 38 additions and 34 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Keras/BackendImpl.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
  4. +6
    -2
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  8. +0
    -2
      src/TensorFlowNET.Core/Keras/Engine/Node.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
  12. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  15. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
  18. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs
  19. +4
    -3
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  20. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  21. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  22. +3
    -0
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  23. +2
    -6
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj

+ 3
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


@@ -7,5 +8,6 @@ namespace Tensorflow.Keras.ArgsDefinition
public class TensorFlowOpLayerArgs : LayerArgs public class TensorFlowOpLayerArgs : LayerArgs
{ {
public NodeDef NodeDef { get; set; } public NodeDef NodeDef { get; set; }
public Dictionary<int, NDArray> Constants { get; set; }
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Keras/BackendImpl.cs View File

@@ -160,9 +160,9 @@ namespace Tensorflow.Keras
/// </summary> /// </summary>
/// <param name="outputs"></param> /// <param name="outputs"></param>
/// <returns></returns> /// <returns></returns>
public Tensor eval_in_eager_or_function(Tensor outputs)
public NDArray eval_in_eager_or_function(Tensor outputs)
{ {
throw new NotImplementedException("");
return outputs.eval();
} }


public class _DummyEagerGraph public class _DummyEagerGraph


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Flatten.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
_channels_first = args.DataFormat == "channels_first"; _channels_first = args.DataFormat == "channels_first";
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
if (_channels_first) if (_channels_first)
{ {


+ 6
- 2
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -69,10 +69,14 @@ namespace Tensorflow.Keras.Engine
} }
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
return base.call(inputs, state, is_training);
return run_internal_graph(inputs, state, is_training);
} }


Tensors run_internal_graph(Tensors inputs, Tensor state = null, bool is_training = false)
{
throw new NotImplementedException("");
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
if (!built) if (!built)
MaybeBuild(inputs); MaybeBuild(inputs);


outputs = call(inputs, state: state, is_training: is_training);
outputs = call_fn(inputs, state: state, is_training: is_training);


outputs = _set_connectivity_metadata_(inputs, outputs); outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs); _handle_activity_regularization(inputs, outputs);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine
if (!dynamic) if (!dynamic)
throw new NotImplementedException(""); throw new NotImplementedException("");


outputs = call(inputs);
outputs = call_fn(inputs);


outputs = _set_connectivity_metadata_(inputs, outputs); outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs); _handle_activity_regularization(inputs, outputs);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param> /// <param name="state"></param>
/// <param name="is_training"></param> /// <param name="is_training"></param>
/// <returns></returns> /// <returns></returns>
protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected virtual Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


+ 0
- 2
src/TensorFlowNET.Core/Keras/Engine/Node.cs View File

@@ -52,8 +52,6 @@ namespace Tensorflow.Keras.Engine
layer.InboundNodes.Add(this); layer.InboundNodes.Add(this);
foreach (var kt in kerasInputs) foreach (var kt in kerasInputs)
{ {
if (kt.KerasHistory == null)
continue;
var inbound_layer = kt.KerasHistory.layer; var inbound_layer = kt.KerasHistory.layer;
if (inbound_layer != null) if (inbound_layer != null)
inbound_layer.OutboundNodes.Add(this); inbound_layer.OutboundNodes.Add(this);


+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs View File

@@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine
built = true; built = true;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
return base.call(inputs, state, is_training);
return base.call_fn(inputs, state, is_training);
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
Tensor outputs = null; Tensor outputs = null;




+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs View File

@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false)
{ {
var outputs = _convolution_op.Apply(inputs, kernel); var outputs = _convolution_op.Apply(inputs, kernel);
if (use_bias) if (use_bias)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false)
{ {
Tensor outputs = null; Tensor outputs = null;
var rank = inputs.rank; var rank = inputs.rank;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dropout.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers
this.args = args; this.args = args;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
var output = tf_utils.smart_cond(is_training, var output = tf_utils.smart_cond(is_training,
() => tf.nn.dropout(inputs, () => tf.nn.dropout(inputs,


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
var dtype = inputs.dtype; var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64) if (dtype != tf.int32 && dtype != tf.int64)


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/LSTM.cs View File

@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers
.ToArray(); .ToArray();
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
return base.call(inputs, state: state, is_training: is_training);
return base.call_fn(inputs, state: state, is_training: is_training);
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4); input_spec = new InputSpec(ndim: 4);
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
int[] pool_shape; int[] pool_shape;
int[] strides; int[] strides;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args; this.args = args;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
scale = math_ops.cast(args.Scale, args.DType); scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType); offset = math_ops.cast(args.Offset, args.DType);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4); this.input_spec = new InputSpec(ndim: 4);
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
return tf.keras.backend.spatial_2d_padding(inputs, return tf.keras.backend.spatial_2d_padding(inputs,
padding: padding, padding: padding,


+ 4
- 3
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
@@ -135,7 +136,7 @@ namespace Tensorflow.Keras.Utils
if (!processed_ops.Contains(op)) if (!processed_ops.Contains(op))
{ {
var layer_inputs = new List<Tensor>(); var layer_inputs = new List<Tensor>();
var constants = new Dictionary<int, NDArray>();
foreach (var (i, op_input) in enumerate(op.inputs._inputs)) foreach (var (i, op_input) in enumerate(op.inputs._inputs))
{ {
if (uses_keras_history(op_input)) if (uses_keras_history(op_input))
@@ -144,8 +145,7 @@ namespace Tensorflow.Keras.Utils
{ {
tf_with(ops.init_scope(), delegate tf_with(ops.init_scope(), delegate
{ {


constants[i] = tf.keras.backend.eval_in_eager_or_function(op_input);
}); });
} }
} }
@@ -155,6 +155,7 @@ namespace Tensorflow.Keras.Utils
var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{ {
NodeDef = op.node_def, NodeDef = op.node_def,
Constants = constants,
Name = op.name Name = op.name
}); });
created_layers.Add(op_layer); created_layers.Add(op_layer);


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -74,7 +74,7 @@ namespace Tensorflow
/// <param name="training"></param> /// <param name="training"></param>
/// <param name="state"></param> /// <param name="state"></param>
/// <returns></returns> /// <returns></returns>
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
var one = constant_op.constant(1, dtype: dtypes.int32); var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency. // Parameters of gates are concatenated into one multiply for efficiency.


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
built = true; built = true;
} }


protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{ {
// Most basic RNN: output = new_state = act(W * input + U * state + B). // Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); var concat = array_ops.concat(new Tensor[] { inputs, state }, 1);


+ 3
- 0
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -124,6 +124,9 @@ namespace Tensorflow
case NPTypeCode.Double: case NPTypeCode.Double:
full_values.Add(value.GetValue<double>(0)); full_values.Add(value.GetValue<double>(0));
break; break;
case NPTypeCode.Boolean:
full_values.Add(value.GetValue<bool>(0));
break;
/*case "String": /*case "String":
full_values.Add(value.Data<byte>()[0]); full_values.Add(value.Data<byte>()[0]);
break;*/ break;*/


+ 2
- 6
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -72,6 +72,8 @@ https://tensorflownet.readthedocs.io</Description>
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<None Remove="FodyWeavers.xml" />
<None Remove="FodyWeavers.xsd" />
<None Remove="Protobuf\README.md" /> <None Remove="Protobuf\README.md" />
</ItemGroup> </ItemGroup>


@@ -84,10 +86,4 @@ https://tensorflownet.readthedocs.io</Description>
<ItemGroup> <ItemGroup>
<Folder Include="Keras\Initializers\" /> <Folder Include="Keras\Initializers\" />
</ItemGroup> </ItemGroup>

<ItemGroup>
<None Update="FodyWeavers.xml">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project> </Project>

Loading…
Cancel
Save