Browse Source

CheckInputFromValidContext

tags/v0.12
Oceania2018 6 years ago
parent
commit
a65d881213
11 changed files with 117 additions and 33 deletions
  1. +5
    -4
      src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
  2. +2
    -3
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +7
    -3
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  5. +36
    -16
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  6. +22
    -0
      src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  7. +2
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  8. +33
    -0
      src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
  9. +2
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  10. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  11. +5
    -4
      src/TensorFlowNET.Core/Util/nest.py.cs

+ 5
- 4
src/TensorFlowNET.Core/Operations/NnOps/rnn.cs View File

@@ -172,7 +172,8 @@ namespace Tensorflow.Operations


for (int i = 0; i < input_ta.Count; i++) for (int i = 0; i < input_ta.Count; i++)
{ {
var (ta, input_) = (input_ta[0], flat_input[0]);
var (ta, input_) = (input_ta[i], flat_input[i]);
ta.unstack(input_);
} }
} }


@@ -185,16 +186,16 @@ namespace Tensorflow.Operations


Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) => Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) =>
{ {
return time < loop_bound;
return item.time < loop_bound;
}; };


// Take a time step of the dynamic RNN. // Take a time step of the dynamic RNN.
Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) =>
{ {
return item;
throw new NotImplementedException("");
}; };


control_flow_ops.while_loop<BodyItemInRnnWhileLoop>(
control_flow_ops.while_loop(
cond: cond, cond: cond,
body: _time_step, body: _time_step,
loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state),


+ 2
- 3
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -30,10 +30,9 @@ namespace Tensorflow
/// </summary> /// </summary>
public void _control_flow_post_processing() public void _control_flow_post_processing()
{ {
foreach(var input_tensor in inputs)
foreach(Tensor input_tensor in inputs)
{ {
//TODO: implement below code dependency
//control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
control_flow_util.CheckInputFromValidContext(this, input_tensor.op);
} }
if (_control_flow_context != null) if (_control_flow_context != null)


+ 1
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -23,6 +23,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {


+ 7
- 3
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -25,6 +25,7 @@ namespace Tensorflow.Operations
internal class _GraphTensorArray internal class _GraphTensorArray
{ {
internal TF_DataType _dtype; internal TF_DataType _dtype;
public TF_DataType dtype => _dtype;


/// <summary> /// <summary>
/// Used to keep track of what tensors the TensorArray should be /// Used to keep track of what tensors the TensorArray should be
@@ -32,14 +33,17 @@ namespace Tensorflow.Operations
/// first tensor written to it. /// first tensor written to it.
/// </summary> /// </summary>
bool _colocate_with_first_write_call; bool _colocate_with_first_write_call;
public bool colocate_with_first_write_call => _colocate_with_first_write_call;


bool _infer_shape; bool _infer_shape;
bool _dynamic_size;
List<TensorShape> _element_shape;
public bool infer_shape => _infer_shape;
public bool _dynamic_size;
public List<TensorShape> _element_shape;


List<Tensor> _colocate_with;
public List<Tensor> _colocate_with;


internal Tensor _handle; internal Tensor _handle;
public Tensor handle => _handle;
internal Tensor _flow; internal Tensor _flow;


public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,


+ 36
- 16
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Operations;
using Tensorflow.Operations.ControlFlows; using Tensorflow.Operations.ControlFlows;
using util = Tensorflow.control_flow_util; using util = Tensorflow.control_flow_util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Util;


namespace Tensorflow namespace Tensorflow
{ {
@@ -251,12 +252,16 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name); return gen_array_ops.identity(data, name: name);
} }


public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null)
public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape[] shapes = null)
{ {
if (shapes == null) if (shapes == null)
return; return;


throw new NotImplementedException("_SetShapeInvariants");
var flat_shapes = nest.flatten2(shapes);
foreach (var (inp, var, shape) in zip(input_vars, enter_vars, flat_shapes))
{
var.set_shape(shape);
}
} }


/// <summary> /// <summary>
@@ -428,12 +433,12 @@ namespace Tensorflow
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.ToArray(); .ToArray();


merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges);
var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges);


ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);


return merges[0];
return new Tensor(IntPtr.Zero);
}); });
} }


@@ -473,22 +478,28 @@ namespace Tensorflow
var res_f_flat = res_f; var res_f_flat = res_f;


var merges = zip(res_f_flat, res_t_flat) var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.Select(pair => merge(new [] { pair.Item1, pair.Item2 }))
.ToArray(); .ToArray();


merges = _convert_flows_to_tensorarrays(orig_res_t, merges);
var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges);


ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);


return merges;
return new[] { new Tensor(IntPtr.Zero) };
}); });
} }


public static Tensor[] _convert_flows_to_tensorarrays<T>(T tensors_or_tensorarrays, Tensor[] tensors_or_flows)
public static ITensorOrTensorArray[] _convert_flows_to_tensorarrays(ITensorOrTensorArray[] tensors_or_tensorarrays, Tensor[] tensors_or_flows)
{ {
// zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
return tensors_or_flows;
return zip(tensors_or_tensorarrays, tensors_or_flows).Select(x =>
{
var (ta, t_or_flow) = (x.Item1, x.Item2);
if (ta is TensorArray ta_1)
return tensor_array_ops.build_ta_with_new_flow(ta_1, t_or_flow) as ITensorOrTensorArray;
else
return t_or_flow as ITensorOrTensorArray;
}).ToArray();
} }


/// <summary> /// <summary>
@@ -592,7 +603,7 @@ namespace Tensorflow
/// <param name="loop_vars"></param> /// <param name="loop_vars"></param>
/// <param name="i"></param> /// <param name="i"></param>
public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars, public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars,
TensorShape shape_invariants = null,
TensorShape[] shape_invariants = null,
int parallel_iterations = 10, int parallel_iterations = 10,
bool back_prop = true, bool back_prop = true,
bool swap_memory = false, bool swap_memory = false,
@@ -617,8 +628,8 @@ namespace Tensorflow
var orig_body = body; var orig_body = body;


LoopVar<TItem> loop_vars_1 = null; LoopVar<TItem> loop_vars_1 = null;
Func<Tensor, TItem, LoopVar<TItem>> body_buildloop = null;
Func<Tensor, TItem, Tensor> cond_buildloop = null;
Func<LoopVar<TItem>, LoopVar<TItem>> body_buildloop = null;
Func<LoopVar<TItem>, Tensor> cond_buildloop = null;


if (try_to_pack) if (try_to_pack)
{ {
@@ -627,9 +638,18 @@ namespace Tensorflow
else else
{ {
loop_vars_1 = new LoopVar<TItem>(counter, loop_vars); loop_vars_1 = new LoopVar<TItem>(counter, loop_vars);
cond_buildloop = (i, lv) =>
math_ops.logical_and(i < maximum_iterations, orig_cond(lv));
body_buildloop = (i, lv) => new LoopVar<TItem>(i + 1, orig_body(lv));
cond_buildloop = (item) =>
{
var (i, lv) = (item.Counter, item.Item);
var oc = orig_cond(lv);
return math_ops.logical_and(i < maximum_iterations, oc);
};

body_buildloop = (item) =>
{
var (i, lv) = (item.Counter, item.Item);
return new LoopVar<TItem>(i + 1, orig_body(lv));
};
} }
try_to_pack = false; try_to_pack = false;




+ 22
- 0
src/TensorFlowNET.Core/Operations/control_flow_util.py.cs View File

@@ -14,7 +14,9 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using Tensorflow.Operations; using Tensorflow.Operations;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -53,5 +55,25 @@ namespace Tensorflow
ctxt = ctxt.outer_context; ctxt = ctxt.outer_context;
return ctxt; return ctxt;
} }

public static void CheckInputFromValidContext(Operation op, Operation input_op)
{
var op_ctxt = op._get_control_flow_context();
var input_ctxt = GetOutputContext(input_op);
var valid = false;
if (input_ctxt == null)
valid = true;
else if (op_ctxt == input_ctxt)
valid = true;
else
{
throw new NotImplementedException("");
}
if (!valid)
{
throw new NotImplementedException("");
}
}
} }
} }

+ 2
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/
using static Tensorflow.Binding;
namespace Tensorflow namespace Tensorflow
{ {
public static class gen_math_ops public static class gen_math_ops


+ 33
- 0
src/TensorFlowNET.Core/Operations/tensor_array_ops.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class tensor_array_ops
{
/// <summary>
/// Builds a TensorArray with a new `flow` tensor.
/// </summary>
/// <param name="old_ta"></param>
/// <param name="flow"></param>
/// <returns></returns>
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
{
var impl = old_ta._implementation;

var new_ta = new TensorArray(
dtype: impl.dtype,
handle: impl.handle,
flow: flow,
infer_shape: impl.infer_shape,
colocate_with_first_write_call: impl.colocate_with_first_write_call);

var new_impl = new_ta._implementation;
new_impl._dynamic_size = impl._dynamic_size;
new_impl._colocate_with = impl._colocate_with;
new_impl._element_shape = impl._element_shape;
return new_ta;
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -20,7 +20,8 @@ Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description> https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.12.0.0</AssemblyVersion> <AssemblyVersion>0.12.0.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.11.0: <PackageReleaseNotes>Changes since v0.11.0:
</PackageReleaseNotes>
1: Add ICanBeFlattened for nest.flatten2.
2:</PackageReleaseNotes>
<LangVersion>7.3</LangVersion> <LangVersion>7.3</LangVersion>
<FileVersion>0.12.0.0</FileVersion> <FileVersion>0.12.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>


+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
/// </summary> /// </summary>
[SuppressMessage("ReSharper", "ConvertToAutoProperty")] [SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike, ITensorOrTensorArray
{ {
private readonly int _id; private readonly int _id;
private readonly Operation _op; private readonly Operation _op;
@@ -178,7 +178,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public void set_shape(TensorShape shape) public void set_shape(TensorShape shape)
{ {
this.shape = shape.rank > 0 ? shape.dims : null;
this.shape = shape.rank >= 0 ? shape.dims : null;
} }


/// <summary> /// <summary>


+ 5
- 4
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -223,9 +223,10 @@ namespace Tensorflow.Util
} }
public static object[] flatten2(ICanBeFlattened structure) public static object[] flatten2(ICanBeFlattened structure)
{
return structure.Flatten();
}
=> structure.Flatten();
public static T[] flatten2<T>(T[] structure)
=> structure;
private static void _flatten_recursive<T>(T obj, List<T> list) private static void _flatten_recursive<T>(T obj, List<T> list)
{ {
@@ -423,7 +424,7 @@ namespace Tensorflow.Util
/// <returns> `flat_sequence` converted to have the same recursive structure as /// <returns> `flat_sequence` converted to have the same recursive structure as
/// `structure`. /// `structure`.
/// </returns> /// </returns>
public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence)
public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence, bool expand_composites = false)
{ {
List<object> flat = null; List<object> flat = null;
if (flat_sequence is List<object>) if (flat_sequence is List<object>)


Loading…
Cancel
Save