From 2dd7ac91952357137880d5371931954cc14b7f75 Mon Sep 17 00:00:00 2001 From: Deepak Battini Date: Fri, 10 Jan 2020 17:31:30 +1030 Subject: [PATCH] Base layer skeleton added --- src/TensorFlowNET.Keras/Engine/BaseLayer.cs | 67 ++++++- src/TensorFlowNET.Keras/Layers/Layer.cs | 202 +++++++++++++++++++- 2 files changed, 263 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Keras/Engine/BaseLayer.cs b/src/TensorFlowNET.Keras/Engine/BaseLayer.cs index 35f65357..36c69843 100644 --- a/src/TensorFlowNET.Keras/Engine/BaseLayer.cs +++ b/src/TensorFlowNET.Keras/Engine/BaseLayer.cs @@ -1,10 +1,73 @@ -using System; +using Keras.Layers; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Engine { - public class Layer + public class TensorFlowOpLayer : Layer { + public TensorFlowOpLayer(string node_def, string name, NDArray[] constants = null, bool trainable = true, string dtype = null) + { + + } + + public override void call(Tensor[] inputs) + { + throw new NotImplementedException(); + } + + public override Dictionary get_config() + { + throw new NotImplementedException(); + } + + private NodeDef _make_node_def(Graph graph) => throw new NotImplementedException(); + + private Tensor[] _make_op(Tensor[] inputs) => throw new NotImplementedException(); + + private Tensor[] _defun_call(Tensor[] inputs) => throw new NotImplementedException(); + } + + public class AddLoss : Layer + { + public AddLoss(bool unconditional) + { + throw new NotImplementedException(); + } + + public override void call(Tensor[] inputs) + { + throw new NotImplementedException(); + } + + public override Dictionary get_config() + { + throw new NotImplementedException(); + } + } + + public class AddMetric : Layer + { + public AddMetric(string aggregation = null, string metric_name = null) + { + throw new NotImplementedException(); + } + + public override void call(Tensor[] inputs) + { + throw new NotImplementedException(); + } + + public override Dictionary get_config() + { + throw new NotImplementedException(); + } + } + + public class KerasHistory + { + } } diff --git a/src/TensorFlowNET.Keras/Layers/Layer.cs b/src/TensorFlowNET.Keras/Layers/Layer.cs index 656e273e..eb231fad 100644 --- a/src/TensorFlowNET.Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Keras/Layers/Layer.cs @@ -5,6 +5,7 @@ using Tensorflow; using Tensorflow.Keras.Constraints; using Tensorflow.Keras.Initializers; using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Regularizers; namespace Keras.Layers @@ -87,7 +88,7 @@ namespace Keras.Layers } } - public Tensor[] weights + private Tensor[] _weights { get { @@ -167,6 +168,38 @@ namespace Keras.Layers } } + public Tensor[] variables + { + get + { + return _weights; + } + } + + public Tensor[] trainable_variables + { + get + { + return trainable_weights; + } + } + + public Tensor[] non_trainable_variables + { + get + { + return non_trainable_weights; + } + } + + private string _compute_dtype + { + get + { + throw new NotImplementedException(); + } + } + public Layer(bool trainable = true, string name = null, string dtype = null, bool @dynamic = false, Dictionary kwargs = null) { @@ -174,7 +207,7 @@ namespace Keras.Layers public void build(TensorShape shape) => throw new NotImplementedException(); - public void call(Tensor[] inputs) => throw new NotImplementedException(); + public virtual void call(Tensor[] inputs) => throw new NotImplementedException(); public void _add_trackable(dynamic trackable_object, bool trainable) => throw new NotImplementedException(); @@ -183,7 +216,7 @@ namespace Keras.Layers dynamic partitioner= null, bool? use_resource= null, VariableSynchronization synchronization= VariableSynchronization.Auto, VariableAggregation aggregation= VariableAggregation.None, Dictionary kwargs = null) => throw new NotImplementedException(); - public Dictionary get_config() => throw new NotImplementedException(); + public virtual Dictionary get_config() => throw new NotImplementedException(); public Layer from_config(Dictionary config) => throw new NotImplementedException(); @@ -224,5 +257,166 @@ namespace Keras.Layers public Tensor[] get_output_at(int node_index) => throw new NotImplementedException(); public int count_params() => throw new NotImplementedException(); - } + + private void _set_dtype_policy(string dtype) => throw new NotImplementedException(); + + private Tensor _maybe_cast_inputs(Tensor inputs) => throw new NotImplementedException(); + + private void _warn_about_input_casting(string input_dtype) => throw new NotImplementedException(); + + private string _name_scope() + { + return name; + } + + private string _obj_reference_counts + { + get + { + throw new NotImplementedException(); + } + } + + private dynamic _attribute_sentinel + { + get + { + throw new NotImplementedException(); + } + } + + private dynamic _call_full_argspec + { + get + { + throw new NotImplementedException(); + } + } + + private string[] _call_fn_args + { + get + { + throw new NotImplementedException(); + } + } + + private string[] _call_accepts_kwargs + { + get + { + throw new NotImplementedException(); + } + } + + private bool _should_compute_mask + { + get + { + throw new NotImplementedException(); + } + } + + private Tensor[] _eager_losses + { + get + { + throw new NotImplementedException(); + } + set + { + throw new NotImplementedException(); + } + } + + private dynamic _trackable_saved_model_saver + { + get + { + throw new NotImplementedException(); + } + } + + private string _object_identifier + { + get + { + throw new NotImplementedException(); + } + } + + private string _tracking_metadata + { + get + { + throw new NotImplementedException(); + } + } + + public Dictionary state + { + get + { + throw new NotImplementedException(); + } + set + { + throw new NotImplementedException(); + } + } + + private void _init_set_name(string name, bool zero_based= true) => throw new NotImplementedException(); + + private Metric _get_existing_metric(string name = null) => throw new NotImplementedException(); + + private void _eager_add_metric(Metric value, string aggregation= null, string name= null) => throw new NotImplementedException(); + + private void _symbolic_add_metric(Metric value, string aggregation = null, string name = null) => throw new NotImplementedException(); + + private void _handle_weight_regularization(string name, VariableV1 variable, Regularizer regularizer) => throw new NotImplementedException(); + + private void _handle_activity_regularization(Tensor[] inputs, Tensor[] outputs) => throw new NotImplementedException(); + + private void _set_mask_metadata(Tensor[] inputs, Tensor[] outputs, Tensor previous_mask) => throw new NotImplementedException(); + + private Tensor[] _collect_input_masks(Tensor[] inputs, Dictionary args, Dictionary kwargs) => throw new NotImplementedException(); + + private bool _call_arg_was_passed(string arg_name, Dictionary args, Dictionary kwargs, bool inputs_in_args= false) => throw new NotImplementedException(); + + private T _get_call_arg_value(string arg_name, Dictionary args, Dictionary kwargs, bool inputs_in_args = false) => throw new NotImplementedException(); + + private (Tensor[], Tensor[]) _set_connectivity_metadata_(Tensor[] inputs, Tensor[] outputs, Dictionary args, Dictionary kwargs) => throw new NotImplementedException(); + + private void _add_inbound_node(Tensor[] input_tensors, Tensor[] output_tensors, Dictionary args = null) => throw new NotImplementedException(); + + private AttrValue _get_node_attribute_at_index(int node_index, string attr, string attr_name) => throw new NotImplementedException(); + + private void _maybe_build(Tensor[] inputs) => throw new NotImplementedException(); + + private void _symbolic_call(Tensor[] inputs) => throw new NotImplementedException(); + + private Dictionary _get_trainable_state() => throw new NotImplementedException(); + + private void _set_trainable_state(bool trainable_state) => throw new NotImplementedException(); + + private void _maybe_create_attribute(string name, object default_value) => throw new NotImplementedException(); + + private void __delattr__(string name) => throw new NotImplementedException(); + + private void __setattr__(string name, object value) => throw new NotImplementedException(); + + private List _gather_children_attribute(string attribute) => throw new NotImplementedException(); + + private List _gather_unique_layers() => throw new NotImplementedException(); + + private List _gather_layers() => throw new NotImplementedException(); + + private bool _is_layer() => throw new NotImplementedException(); + + private void _init_call_fn_args() => throw new NotImplementedException(); + + public dynamic _list_extra_dependencies_for_serialization(dynamic serialization_cache) => throw new NotImplementedException(); + + public dynamic _list_functions_for_serialization(dynamic serialization_cache) => throw new NotImplementedException(); + } }