Browse Source

tf2.0

pull/369/head
arnavdas88 6 years ago
parent
commit
bed0013749
4 changed files with 194 additions and 2 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Layers/Layer.cs
  3. +142
    -0
      src/TensorFlowNET.Core/Module/Module.cs
  4. +49
    -0
      src/TensorFlowNET.Core/tf2.cs

+ 2
- 1
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -20,6 +20,7 @@ using System.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using Tensorflow.Module;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -32,7 +33,7 @@ namespace Tensorflow.Keras.Layers
///
/// tensorflow\python\keras\engine\base_layer.py
/// </summary>
public class Layer : AutoTrackable
public class Layer : Module.Module
{
/// <summary>
/// Indicates whether `build` needs to be called upon layer call, to create


+ 1
- 1
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -128,7 +128,7 @@ namespace Tensorflow.Layers
else
{
init_graph = default_graph;
existing_variables = variables.global_variables().ToArray();
existing_variables = Tensorflow.variables.global_variables().ToArray();
}

if(dtype == TF_DataType.DtInvalid)


+ 142
- 0
src/TensorFlowNET.Core/Module/Module.cs View File

@@ -0,0 +1,142 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Tensorflow;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Module
{
/// <summary>
/// Base neural network module class.
/// A module is a named container for `tf.Variable`s, other `tf.Module`s and
/// functions which apply to user input. For example a dense layer in a neural
/// network might be implemented as a `tf.Module`:
///
/// tensorflow/python/module/module.py
/// </summary>
public class Module : AutoTrackable
{
SortedSet<string> _TF_MODULE_IGNORED_PROPERTIES = new SortedSet<string>(){ "_self_unconditional_checkpoint_dependencies", "_self_unconditional_dependency_names"};

protected string _name;
/// <summary>
/// Returns the name of this module as passed or determined in the ctor.
/// NOTE: This is not the same as the `self.name_scope.name` which includes
/// parent module names.
/// </summary>
public string name => _name;

protected ops.NameScope _name_scope;
protected ops.NameScope _scope_name;
/// <summary>
/// Returns a `tf.name_scope` instance for this class.
/// </summary>
public ops.NameScope name_scope
{
get{
if(tf2.enabled())
return this._name_scope;
return ops.name_scope(_scope_name);
}
}
/// <summary>
/// Sequence of variables owned by this module and it's submodules.
/// Note: this method uses reflection to find variables on the current instance
/// and submodules. For performance reasons you may wish to cache the result
/// of calling this method if you don't expect the return value to change.
/// </summary>
/// <returns>
/// A sequence of variables for the current module (sorted by attribute
/// name) followed by variables from all submodules recursively (breadth
/// first).
/// </returns>
public ValueTuple variables => throw new NotImplementedException();

/// <summary>
/// Sequence of variables owned by this module and it's submodules.
/// Note: this method uses reflection to find variables on the current instance
/// and submodules. For performance reasons you may wish to cache the result
/// of calling this method if you don't expect the return value to change.
/// </summary>
/// <returns>
/// A sequence of variables for the current module (sorted by attribute
/// name) followed by variables from all submodules recursively (breadth
/// first).
/// </returns>
public ValueTuple trainable_variables => throw new NotImplementedException();

/// <summary>
/// Sequence of all sub-modules.
/// Submodules are modules which are properties of this module, or found as
/// properties of modules which are properties of this module (and so on).
/// </summary>
/// <returns>
/// A sequence of all submodules.
/// </returns>
public ValueTuple submodules => throw new NotImplementedException();

public Module(string name = null)
{
if(name == null)
name = Module.camel_to_snake(this.GetType().Name);
else
if (!valid_identifier(name))
new ValueError(name +
" is not a valid module name. Module names must be valid Python " +
"identifiers (e.g. a valid class name).");
this._name = name;
//if (tf2.enabled())
// using( var scope_name = ops.name_scope_v2(name) )
// this._name_scope = ops.name_scope_v2(scope_name);
//else
using(var scope_name = ops.name_scope(name))
this._scope_name = scope_name;
}

public object with_name_scope(Func<object> method)
{
throw new NotImplementedException();
}

// NOTE : Below are all the static functions
static string _CAMEL_TO_SNAKE_R = "((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))";
public static string camel_to_snake(string value)
{
return Regex.Match(value, _CAMEL_TO_SNAKE_R).Result("_${1}").ToLower();
}
static string _VALID_IDENTIFIER = "^[a-zA-Z_]([a-zA-Z0-9_])*$";
public static bool valid_identifier(string name)
{
return Regex.Match(name, _VALID_IDENTIFIER).Success;
}
public static Array _flatten_module(Module module,
bool recursive,
object predicate,
object attribute_traversal_key,
object attributes_to_ignore,
bool with_path,
ValueTuple module_path = default(ValueTuple),
bool seen=false)
{
throw new System.NotImplementedException();
}
}
}

+ 49
- 0
src/TensorFlowNET.Core/tf2.cs View File

@@ -0,0 +1,49 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;

namespace Tensorflow
{
public static class tf2
{
public static bool _force_enable;
internal static int flag = 0;
public static void enable()
{
_force_enable = true;
flag = 1;
}
public static void disable()
{
_force_enable = false;
flag = 1;
}
public static bool enabled()
{
if(flag == 0)
{
var data = "";
((System.Collections.Generic.Dictionary<string, string>)System.Environment.GetEnvironmentVariables()).TryGetValue("TF2_BEHAVIOR", out data);
if(string.IsNullOrEmpty(data))
data = "0";
return data != "0";
}
else
return _force_enable;
}
}
}

Loading…
Cancel
Save