Browse Source

VariableTest #888

tags/TimeSeries
Oceania2018 4 years ago
parent
commit
e40be93380
7 changed files with 48 additions and 12 deletions
  1. +9
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.v1.cs
  2. +1
    -4
      src/TensorFlowNET.Core/APIs/tf.variable.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  4. +6
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  5. +1
    -3
      src/TensorFlowNET.Core/tensorflow.cs
  6. +2
    -2
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  7. +26
    -0
      test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs

+ 9
- 0
src/TensorFlowNET.Core/APIs/tf.compat.v1.cs View File

@@ -47,5 +47,14 @@ namespace Tensorflow
trainable: trainable, trainable: trainable,
collections: collections); collections: collections);
} }

public Operation global_variables_initializer()
{
var g = variables.global_variables();
return variables.variables_initializer(g.ToArray());
}

public Session Session()
=> new Session().as_default();
} }
} }

+ 1
- 4
src/TensorFlowNET.Core/APIs/tf.variable.cs View File

@@ -37,10 +37,7 @@ namespace Tensorflow
=> variables.variables_initializer(var_list, name: name); => variables.variables_initializer(var_list, name: name);


public Operation global_variables_initializer() public Operation global_variables_initializer()
{
var g = variables.global_variables();
return variables.variables_initializer(g.ToArray());
}
=> tf.compat.v1.global_variables_initializer();


/// <summary> /// <summary>
/// Returns all variables created with `trainable=True`. /// Returns all variables created with `trainable=True`.


+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow> <TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.60.5</Version>
<Version>0.60.6</Version>
<LangVersion>9.0</LangVersion> <LangVersion>9.0</LangVersion>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
@@ -20,7 +20,7 @@
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models. Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description> https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.60.5.0</AssemblyVersion>
<AssemblyVersion>0.60.6.0</AssemblyVersion>
<PackageReleaseNotes>tf.net 0.60.x and above are based on tensorflow native 2.6.0 <PackageReleaseNotes>tf.net 0.60.x and above are based on tensorflow native 2.6.0


* Eager Mode is added finally. * Eager Mode is added finally.
@@ -35,7 +35,7 @@ Keras API is a separate package released as TensorFlow.Keras.
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.
tf.net 0.5x.x aligns with TensorFlow v2.5.x native library. tf.net 0.5x.x aligns with TensorFlow v2.5.x native library.
tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes> tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes>
<FileVersion>0.60.5.0</FileVersion>
<FileVersion>0.60.6.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


+ 6
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -17,6 +17,7 @@
using Google.Protobuf; using Google.Protobuf;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.NumPy;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -229,5 +230,10 @@ namespace Tensorflow


throw new NotImplementedException("to_proto RefVariable"); throw new NotImplementedException("to_proto RefVariable");
} }

public NDArray eval(Session session = null)
{
return _graph_element.eval(session);
}
} }
} }

+ 1
- 3
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -93,9 +93,7 @@ namespace Tensorflow
=> ops.get_default_session(); => ops.get_default_session();


public Session Session() public Session Session()
{
return new Session().as_default();
}
=> compat.v1.Session();


public Session Session(Graph graph, ConfigProto config = null) public Session Session(Graph graph, ConfigProto config = null)
{ {


+ 2
- 2
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType> <RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.6.5.0</AssemblyVersion>
<FileVersion>0.6.5.0</FileVersion>
<AssemblyVersion>0.6.6.0</AssemblyVersion>
<FileVersion>0.6.6.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
</PropertyGroup> </PropertyGroup>




+ 26
- 0
test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs View File

@@ -0,0 +1,26 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class VariableTest : GraphModeTestBase
{
[TestMethod]
public void InitVariable()
{
var v = tf.Variable(new[] { 1, 2 });
var init = tf.compat.v1.global_variables_initializer();

using var sess = tf.compat.v1.Session();
sess.run(init);
// Usage passing the session explicitly.
print(v.eval(sess));
// Usage with the default session. The 'with' block
// above makes 'sess' the default session.
print(v.eval());
}
}
}

Loading…
Cancel
Save