Browse Source

tf.while_loop, add ICanBeFlattened #348

tags/v0.12
Oceania2018 6 years ago
parent
commit
547c4e6bf4
7 changed files with 48 additions and 32 deletions
  1. +10
    -7
      src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs
  2. +4
    -4
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  3. +11
    -0
      src/TensorFlowNET.Core/Operations/IFlatten.cs
  4. +9
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  6. +7
    -19
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  7. +6
    -0
      src/TensorFlowNET.Core/Util/nest.py.cs

+ 10
- 7
src/TensorFlowNET.Core/Operations/ControlFlows/LoopVar.cs View File

@@ -4,22 +4,25 @@ using System.Text;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
internal class LoopVar<TItem>
internal class LoopVar<TItem> : ICanBeFlattened
{ {
public Tensor Counter { get; } public Tensor Counter { get; }
public TItem[] Items { get; }
public TItem Item { get; } public TItem Item { get; }


public LoopVar(Tensor counter, TItem[] items)
public LoopVar(Tensor counter, TItem item)
{ {
Counter = counter; Counter = counter;
Items = items;
Item = item;
} }


public LoopVar(Tensor counter, TItem item)
public object[] Flatten()
{ {
Counter = counter;
Item = item;
var elements = new List<object> { Counter };
if (typeof(TItem).GetInterface(typeof(ICanBeFlattened).Name) != null)
elements.AddRange((Item as ICanBeFlattened).Flatten());
else
elements.Add(Item);
return elements.ToArray();
} }
} }
} }

+ 4
- 4
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -109,7 +109,7 @@ namespace Tensorflow.Operations
/// </summary> /// </summary>
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
Func<Tensor, TItem, LoopVar<TItem>> body, Func<Tensor, TItem, LoopVar<TItem>> body,
TItem loop_vars,
LoopVar<TItem> loop_vars,
TensorShape shape_invariants, TensorShape shape_invariants,
bool return_same_structure) bool return_same_structure)
{ {
@@ -143,8 +143,8 @@ namespace Tensorflow.Operations


private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
Func<Tensor, TItem, LoopVar<TItem>> body, Func<Tensor, TItem, LoopVar<TItem>> body,
TItem original_loop_vars,
TItem loop_vars,
LoopVar<TItem> original_loop_vars,
LoopVar<TItem> loop_vars,
TensorShape shape_invariants) TensorShape shape_invariants)
{ {
var flat_loop_vars = original_loop_vars; var flat_loop_vars = original_loop_vars;
@@ -152,7 +152,7 @@ namespace Tensorflow.Operations
// Convert TensorArrays to their flow variables // Convert TensorArrays to their flow variables
var loop_vars_tensor = nest.map_structure( var loop_vars_tensor = nest.map_structure(
_convert_tensorarray_to_flow, _convert_tensorarray_to_flow,
nest.flatten(loop_vars));
nest.flatten2(loop_vars));


// Let the context know the loop variables so the loop variables // Let the context know the loop variables so the loop variables
// would be added in the outer contexts properly. // would be added in the outer contexts properly.


+ 11
- 0
src/TensorFlowNET.Core/Operations/IFlatten.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
public interface ICanBeFlattened
{
object[] Flatten();
}
}

+ 9
- 1
src/TensorFlowNET.Core/Operations/NnOps/BodyItemInRnnWhileLoop.cs View File

@@ -4,7 +4,7 @@ using System.Text;


namespace Tensorflow.Operations namespace Tensorflow.Operations
{ {
internal class BodyItemInRnnWhileLoop
internal class BodyItemInRnnWhileLoop : ICanBeFlattened
{ {
/// <summary> /// <summary>
/// int32 scalar Tensor. /// int32 scalar Tensor.
@@ -28,5 +28,13 @@ namespace Tensorflow.Operations


public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item)
=> (item.time, item.output_ta_t, item.state); => (item.time, item.output_ta_t, item.state);

public object[] Flatten()
{
var elements = new List<object> { time };
elements.AddRange(output_ta_t);
elements.Add(state);
return elements.ToArray();
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -642,7 +642,7 @@ namespace Tensorflow
if (loop_context.outer_context == null) if (loop_context.outer_context == null)
ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context);


var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants,
var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars_1, shape_invariants,
return_same_structure); return_same_structure);


if (maximum_iterations != null) if (maximum_iterations != null)


+ 7
- 19
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.0</TargetTensorFlow> <TargetTensorFlow>1.14.0</TargetTensorFlow>
<Version>0.11.8.1</Version>
<Version>0.12.0</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,25 +16,13 @@
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl> <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.11.8.1</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.10.0:
1. Upgrade NumSharp to v0.20.3.
2. Add DisposableObject class to manage object lifetime.
3. Add tf.no_op, tf.nn.in_top_k, tf.GraphKeys and tf.trainable_variables.
4. Change tensorflow to non-static class in order to execute some initialization process.
5. Overload session.run(), make syntax simpler.
6. Add Local Response Normalization.
7. Add tf.image related APIs.
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
9. MultiThread is safe.
10. Support n-dim indexing for tensor.
11. Add RegisterNoGradients
12. Add CumsumGrad, BroadcastToGrad.
13. Return VariableV1 instead of RefVariable.
14. Add Tensor overload to GradientDescentOptimizer.</PackageReleaseNotes>
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.12.0.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.11.0:
</PackageReleaseNotes>
<LangVersion>7.3</LangVersion> <LangVersion>7.3</LangVersion>
<FileVersion>0.11.8.1</FileVersion>
<FileVersion>0.12.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


+ 6
- 0
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -19,6 +19,7 @@ using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using NumSharp; using NumSharp;
using Tensorflow.Operations;
namespace Tensorflow.Util namespace Tensorflow.Util
{ {
@@ -221,6 +222,11 @@ namespace Tensorflow.Util
return list; return list;
} }
public static object[] flatten2(ICanBeFlattened structure)
{
return structure.Flatten();
}
private static void _flatten_recursive<T>(T obj, List<T> list) private static void _flatten_recursive<T>(T obj, List<T> list)
{ {
switch(obj) switch(obj)


Loading…
Cancel
Save