From bed0013749d44002c7f51daefb67a0fa3f5679f1 Mon Sep 17 00:00:00 2001 From: arnavdas88 Date: Tue, 27 Aug 2019 12:32:50 +0530 Subject: [PATCH] tf2.0 --- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 3 +- src/TensorFlowNET.Core/Layers/Layer.cs | 2 +- src/TensorFlowNET.Core/Module/Module.cs | 142 +++++++++++++++++++ src/TensorFlowNET.Core/tf2.cs | 49 +++++++ 4 files changed, 194 insertions(+), 2 deletions(-) create mode 100644 src/TensorFlowNET.Core/Module/Module.cs create mode 100644 src/TensorFlowNET.Core/tf2.cs diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 6681ec56..5f2270d3 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -20,6 +20,7 @@ using System.Linq; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; using Tensorflow.Train; +using Tensorflow.Module; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers @@ -32,7 +33,7 @@ namespace Tensorflow.Keras.Layers /// /// tensorflow\python\keras\engine\base_layer.py /// - public class Layer : AutoTrackable + public class Layer : Module.Module { /// /// Indicates whether `build` needs to be called upon layer call, to create diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index a3ae3356..b86c4063 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -128,7 +128,7 @@ namespace Tensorflow.Layers else { init_graph = default_graph; - existing_variables = variables.global_variables().ToArray(); + existing_variables = Tensorflow.variables.global_variables().ToArray(); } if(dtype == TF_DataType.DtInvalid) diff --git a/src/TensorFlowNET.Core/Module/Module.cs b/src/TensorFlowNET.Core/Module/Module.cs new file mode 100644 index 00000000..9e310b14 --- /dev/null +++ b/src/TensorFlowNET.Core/Module/Module.cs @@ -0,0 +1,142 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using Tensorflow; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow.Module +{ + /// + /// Base neural network module class. + /// A module is a named container for `tf.Variable`s, other `tf.Module`s and + /// functions which apply to user input. For example a dense layer in a neural + /// network might be implemented as a `tf.Module`: + /// + /// tensorflow/python/module/module.py + /// + public class Module : AutoTrackable + { + SortedSet _TF_MODULE_IGNORED_PROPERTIES = new SortedSet(){ "_self_unconditional_checkpoint_dependencies", "_self_unconditional_dependency_names"}; + + protected string _name; + /// + /// Returns the name of this module as passed or determined in the ctor. + /// NOTE: This is not the same as the `self.name_scope.name` which includes + /// parent module names. + /// + public string name => _name; + + protected ops.NameScope _name_scope; + protected ops.NameScope _scope_name; + /// + /// Returns a `tf.name_scope` instance for this class. + /// + public ops.NameScope name_scope + { + get{ + if(tf2.enabled()) + return this._name_scope; + return ops.name_scope(_scope_name); + } + } + /// + /// Sequence of variables owned by this module and it's submodules. + /// Note: this method uses reflection to find variables on the current instance + /// and submodules. For performance reasons you may wish to cache the result + /// of calling this method if you don't expect the return value to change. + /// + /// + /// A sequence of variables for the current module (sorted by attribute + /// name) followed by variables from all submodules recursively (breadth + /// first). + /// + public ValueTuple variables => throw new NotImplementedException(); + + /// + /// Sequence of variables owned by this module and it's submodules. + /// Note: this method uses reflection to find variables on the current instance + /// and submodules. For performance reasons you may wish to cache the result + /// of calling this method if you don't expect the return value to change. + /// + /// + /// A sequence of variables for the current module (sorted by attribute + /// name) followed by variables from all submodules recursively (breadth + /// first). + /// + public ValueTuple trainable_variables => throw new NotImplementedException(); + + /// + /// Sequence of all sub-modules. + /// Submodules are modules which are properties of this module, or found as + /// properties of modules which are properties of this module (and so on). + /// + /// + /// A sequence of all submodules. + /// + public ValueTuple submodules => throw new NotImplementedException(); + + public Module(string name = null) + { + if(name == null) + name = Module.camel_to_snake(this.GetType().Name); + else + if (!valid_identifier(name)) + new ValueError(name + + " is not a valid module name. Module names must be valid Python " + + "identifiers (e.g. a valid class name)."); + this._name = name; + //if (tf2.enabled()) + // using( var scope_name = ops.name_scope_v2(name) ) + // this._name_scope = ops.name_scope_v2(scope_name); + //else + using(var scope_name = ops.name_scope(name)) + this._scope_name = scope_name; + } + + public object with_name_scope(Func method) + { + throw new NotImplementedException(); + } + + // NOTE : Below are all the static functions + static string _CAMEL_TO_SNAKE_R = "((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))"; + public static string camel_to_snake(string value) + { + return Regex.Match(value, _CAMEL_TO_SNAKE_R).Result("_${1}").ToLower(); + } + static string _VALID_IDENTIFIER = "^[a-zA-Z_]([a-zA-Z0-9_])*$"; + public static bool valid_identifier(string name) + { + return Regex.Match(name, _VALID_IDENTIFIER).Success; + } + public static Array _flatten_module(Module module, + bool recursive, + object predicate, + object attribute_traversal_key, + object attributes_to_ignore, + bool with_path, + ValueTuple module_path = default(ValueTuple), + bool seen=false) + { + throw new System.NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/tf2.cs b/src/TensorFlowNET.Core/tf2.cs new file mode 100644 index 00000000..eb2632b5 --- /dev/null +++ b/src/TensorFlowNET.Core/tf2.cs @@ -0,0 +1,49 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; + +namespace Tensorflow +{ + public static class tf2 + { + public static bool _force_enable; + internal static int flag = 0; + public static void enable() + { + _force_enable = true; + flag = 1; + } + public static void disable() + { + _force_enable = false; + flag = 1; + } + public static bool enabled() + { + if(flag == 0) + { + var data = ""; + ((System.Collections.Generic.Dictionary)System.Environment.GetEnvironmentVariables()).TryGetValue("TF2_BEHAVIOR", out data); + if(string.IsNullOrEmpty(data)) + data = "0"; + return data != "0"; + } + else + return _force_enable; + } + } +}