Browse Source

return array instead of tuple for layer.call

tags/v0.12
Oceania2018 6 years ago
parent
commit
38ad490c3e
6 changed files with 20 additions and 19 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  5. +10
    -9
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  6. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs

+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
Tensor outputs = null;

if (fused)
{
outputs = _fused_batch_norm(inputs, training: training);
return (outputs, outputs);
return new[] { outputs, outputs };
}

throw new NotImplementedException("BatchNormalization call");


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias)
@@ -126,7 +126,7 @@ namespace Tensorflow.Keras.Layers
if (activation != null)
outputs = activation.Activate(outputs);

return (outputs, outputs);
return new[] { outputs, outputs };
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
Tensor outputs = null;
var rank = inputs.rank;
@@ -90,7 +90,7 @@ namespace Tensorflow.Keras.Layers
if (activation != null)
outputs = activation.Activate(outputs);

return (outputs, outputs);
return new[] { outputs, outputs };
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -50,14 +50,14 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)
inputs = math_ops.cast(inputs, tf.int32);

var @out = embedding_ops.embedding_lookup(embeddings, inputs);
return (@out, @out);
return new[] { @out, @out };
}
}
}

+ 10
- 9
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -103,14 +103,14 @@ namespace Tensorflow.Keras.Layers
_inbound_nodes = new List<Node>();
}

public (Tensor, Tensor) __call__(Tensor[] inputs,
public Tensor[] __call__(Tensor[] inputs,
Tensor training = null,
Tensor state = null,
VariableScope scope = null)
{
var input_list = inputs;
var input = inputs[0];
Tensor outputs = null;
Tensor[] outputs = null;

// We will attempt to build a TF graph if & only if all inputs are symbolic.
// This is always the case in graph mode. It can also be the case in eager
@@ -142,25 +142,26 @@ namespace Tensorflow.Keras.Layers
// overridden).
_maybe_build(inputs[0]);

(input, outputs) = call(inputs[0],
outputs = call(inputs[0],
training: training,
state: state);

(input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null);
});
}

return (input, outputs);
return outputs;
}

private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs)
{
//_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
return (inputs, outputs);
}

private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs)
{
//if(_activity_regularizer != null)
{
@@ -168,7 +169,7 @@ namespace Tensorflow.Keras.Layers
}
}

private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask)
private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask)
{

}
@@ -178,9 +179,9 @@ namespace Tensorflow.Keras.Layers
return null;
}

protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
return (inputs, inputs);
throw new NotImplementedException("");
}

protected virtual string _name_scope()


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}

protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
int[] pool_shape;
if (data_format == "channels_last")
@@ -64,7 +64,7 @@ namespace Tensorflow.Keras.Layers
padding: padding.ToUpper(),
data_format: conv_utils.convert_data_format(data_format, 4));

return (outputs, outputs);
return new[] { outputs, outputs };
}
}
}

Loading…
Cancel
Save