| @@ -30,6 +30,10 @@ namespace Tensorflow | |||||
| if (src_graph == null) | if (src_graph == null) | ||||
| src_graph = ops.get_default_graph(); | src_graph = ops.get_default_graph(); | ||||
| // If src_graph is a _FuncGraph (i.e. a function body), gather it and all | |||||
| // ancestor graphs. This is necessary for correctly handling captured values. | |||||
| var curr_graph = src_graph; | |||||
| var ys1 = _AsList(ys); | var ys1 = _AsList(ys); | ||||
| var xs1 = _AsList(xs); | var xs1 = _AsList(xs); | ||||
| List<Tensor> grad_ys1 = null; | List<Tensor> grad_ys1 = null; | ||||
| @@ -47,7 +51,10 @@ namespace Tensorflow | |||||
| string grad_scope = ""; | string grad_scope = ""; | ||||
| using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) | using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) | ||||
| { | |||||
| grad_scope = namescope; | grad_scope = namescope; | ||||
| } | |||||
| } | } | ||||
| private static List<Tensor> _AsList(object ys) | private static List<Tensor> _AsList(object ys) | ||||
| @@ -173,7 +173,6 @@ namespace Tensorflow | |||||
| string new_stack = ""; | string new_stack = ""; | ||||
| if (name.EndsWith("/")) | if (name.EndsWith("/")) | ||||
| new_stack = ops._name_from_scope_name(name); | new_stack = ops._name_from_scope_name(name); | ||||
| else | else | ||||
| @@ -15,14 +15,15 @@ namespace Tensorflow | |||||
| var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
| var op_def = g.GetOpDef(op_type_name); | var op_def = g.GetOpDef(op_type_name); | ||||
| // Default name if not specified. | |||||
| if (String.IsNullOrEmpty(name)) | if (String.IsNullOrEmpty(name)) | ||||
| { | |||||
| name = op_type_name; | name = op_type_name; | ||||
| } | |||||
| string scope = ""; | |||||
| using (var namescope = new ops.name_scope<object>(name)) | |||||
| scope = namescope; | |||||
| // Check for deprecation | |||||
| if(op_def.Deprecation != null && op_def.Deprecation.Version > 0) | |||||
| { | |||||
| } | |||||
| var default_type_attr_map = new Dictionary<string, object>(); | var default_type_attr_map = new Dictionary<string, object>(); | ||||
| foreach (var attr_def in op_def.Attr) | foreach (var attr_def in op_def.Attr) | ||||
| @@ -39,101 +40,107 @@ namespace Tensorflow | |||||
| var inputs = new List<Tensor>(); | var inputs = new List<Tensor>(); | ||||
| var input_types = new List<TF_DataType>(); | var input_types = new List<TF_DataType>(); | ||||
| // Perform input type inference | |||||
| foreach (var input_arg in op_def.InputArg) | |||||
| string scope = ""; | |||||
| using (var namescope = new ops.name_scope<object>(name)) | |||||
| { | { | ||||
| var input_name = input_arg.Name; | |||||
| if (keywords[input_name] is double int_value) | |||||
| { | |||||
| keywords[input_name] = constant_op.Constant(int_value, input_name); | |||||
| } | |||||
| scope = namescope; | |||||
| if (keywords[input_name] is Tensor value) | |||||
| // Perform input type inference | |||||
| foreach (var input_arg in op_def.InputArg) | |||||
| { | { | ||||
| if (keywords.ContainsKey(input_name)) | |||||
| var input_name = input_arg.Name; | |||||
| if (keywords[input_name] is double int_value) | |||||
| { | { | ||||
| inputs.Add(value); | |||||
| keywords[input_name] = constant_op.Constant(int_value, input_name); | |||||
| } | } | ||||
| if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||||
| if (keywords[input_name] is Tensor value) | |||||
| { | { | ||||
| attrs[input_arg.TypeAttr] = value.dtype; | |||||
| if (keywords.ContainsKey(input_name)) | |||||
| { | |||||
| inputs.Add(value); | |||||
| } | |||||
| if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||||
| { | |||||
| attrs[input_arg.TypeAttr] = value.dtype; | |||||
| } | |||||
| if (input_arg.IsRef) | |||||
| { | |||||
| } | |||||
| else | |||||
| { | |||||
| input_types.Add(value.dtype); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| if (input_arg.IsRef) | |||||
| { | |||||
| } | |||||
| else | |||||
| // Process remaining attrs | |||||
| foreach (var attr in op_def.Attr) | |||||
| { | |||||
| if (keywords.ContainsKey(attr.Name)) | |||||
| { | { | ||||
| input_types.Add(value.dtype); | |||||
| attrs[attr.Name] = keywords[attr.Name]; | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| // Process remaining attrs | |||||
| foreach (var attr in op_def.Attr) | |||||
| { | |||||
| if (keywords.ContainsKey(attr.Name)) | |||||
| // Convert attr values to AttrValue protos. | |||||
| var attr_protos = new Dictionary<string, AttrValue>(); | |||||
| foreach (var attr_def in op_def.Attr) | |||||
| { | { | ||||
| attrs[attr.Name] = keywords[attr.Name]; | |||||
| } | |||||
| } | |||||
| var key = attr_def.Name; | |||||
| var value = attrs[key]; | |||||
| var attr_value = new AttrValue(); | |||||
| // Convert attr values to AttrValue protos. | |||||
| var attr_protos = new Dictionary<string, AttrValue>(); | |||||
| foreach (var attr_def in op_def.Attr) | |||||
| { | |||||
| var key = attr_def.Name; | |||||
| var value = attrs[key]; | |||||
| var attr_value = new AttrValue(); | |||||
| switch (attr_def.Type) | |||||
| { | |||||
| case "string": | |||||
| attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
| break; | |||||
| case "type": | |||||
| attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||||
| break; | |||||
| case "bool": | |||||
| attr_value.B = (bool)value; | |||||
| break; | |||||
| case "shape": | |||||
| attr_value.Shape = value == null ? | |||||
| attr_def.DefaultValue.Shape : | |||||
| tensor_util.as_shape((long[])value); | |||||
| break; | |||||
| default: | |||||
| throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||||
| } | |||||
| switch (attr_def.Type) | |||||
| { | |||||
| case "string": | |||||
| attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value); | |||||
| break; | |||||
| case "type": | |||||
| attr_value.Type = _MakeType((TF_DataType)value, attr_def); | |||||
| break; | |||||
| case "bool": | |||||
| attr_value.B = (bool)value; | |||||
| break; | |||||
| case "shape": | |||||
| attr_value.Shape = value == null ? | |||||
| attr_def.DefaultValue.Shape : | |||||
| tensor_util.as_shape((long[])value); | |||||
| break; | |||||
| default: | |||||
| throw new InvalidDataException($"attr_def.Type {attr_def.Type}"); | |||||
| } | |||||
| attr_protos[key] = attr_value; | |||||
| } | |||||
| attr_protos[key] = attr_value; | |||||
| } | |||||
| // Determine output types (possibly using attrs) | |||||
| var output_types = new List<TF_DataType>(); | |||||
| // Determine output types (possibly using attrs) | |||||
| var output_types = new List<TF_DataType>(); | |||||
| foreach (var arg in op_def.OutputArg) | |||||
| { | |||||
| if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||||
| foreach (var arg in op_def.OutputArg) | |||||
| { | { | ||||
| if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||||
| { | |||||
| } | |||||
| else if (!String.IsNullOrEmpty(arg.TypeAttr)) | |||||
| { | |||||
| output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||||
| } | |||||
| } | } | ||||
| else if (!String.IsNullOrEmpty(arg.TypeAttr)) | |||||
| { | |||||
| output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||||
| } | |||||
| } | |||||
| // Add Op to graph | |||||
| var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||||
| name: scope, | |||||
| input_types: input_types.ToArray(), | |||||
| attrs: attr_protos, | |||||
| op_def: op_def); | |||||
| // Add Op to graph | |||||
| var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||||
| name: scope, | |||||
| input_types: input_types.ToArray(), | |||||
| attrs: attr_protos, | |||||
| op_def: op_def); | |||||
| return op; | |||||
| return op; | |||||
| } | |||||
| } | } | ||||
| public DataType _MakeType(TF_DataType v, AttrDef attr_def) | public DataType _MakeType(TF_DataType v, AttrDef attr_def) | ||||
| @@ -4,9 +4,9 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <Version>0.0.2</Version> | |||||
| <Version>0.0.3</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Company>SciSharp.org</Company> | |||||
| <Company>SciSharp STACK</Company> | |||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| <Copyright>Apache 2.0</Copyright> | <Copyright>Apache 2.0</Copyright> | ||||
| <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | ||||
| @@ -16,7 +16,7 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | ||||
| <Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.0.2.0</AssemblyVersion> | |||||
| <AssemblyVersion>0.0.3.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>API updated</PackageReleaseNotes> | <PackageReleaseNotes>API updated</PackageReleaseNotes> | ||||
| <LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| @@ -4,6 +4,10 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | |||||
| /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | |||||
| /// The enum values here are identical to corresponding values in types.proto. | |||||
| /// </summary> | |||||
| public enum TF_DataType | public enum TF_DataType | ||||
| { | { | ||||
| DtInvalid = 0, | DtInvalid = 0, | ||||
| @@ -30,6 +34,8 @@ namespace Tensorflow | |||||
| TF_RESOURCE = 20, | TF_RESOURCE = 20, | ||||
| TF_VARIANT = 21, | TF_VARIANT = 21, | ||||
| TF_UINT32 = 22, | TF_UINT32 = 22, | ||||
| TF_UINT64 = 23 | |||||
| TF_UINT64 = 23, | |||||
| DtDoubleRef = 102, // DT_DOUBLE_REF | |||||
| } | } | ||||
| } | } | ||||
| @@ -19,7 +19,10 @@ namespace Tensorflow | |||||
| public Graph Graph => op.Graph; | public Graph Graph => op.Graph; | ||||
| public Operation op { get; } | public Operation op { get; } | ||||
| public string name; | |||||
| /// <summary> | |||||
| /// The string name of this tensor. | |||||
| /// </summary> | |||||
| public string name => $"{(op == null ? "Operation was not named" : $"{op.Name}:{value_index}")}"; | |||||
| public int value_index { get; } | public int value_index { get; } | ||||
| @@ -222,7 +225,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| return $"{name} {dtype} {rank} {string.Join(",", shape)}"; | |||||
| return $"{name} {dtype.ToString()} {rank} {string.Join(",", shape)}"; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -17,6 +17,10 @@ namespace Tensorflow | |||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public double LearningRate { get; set; } | public double LearningRate { get; set; } | ||||
| public Tensor LearningRateTensor { get; set; } | public Tensor LearningRateTensor { get; set; } | ||||
| public bool _use_locking; | |||||
| public Dictionary<string, object> _slots; | |||||
| public Dictionary<string, object> _non_slot_dict; | |||||
| public Dictionary<string, object> _deferred_slot_restorations; | |||||
| public Optimizer(double learning_rate, bool use_locking, string name = "") | public Optimizer(double learning_rate, bool use_locking, string name = "") | ||||
| { | { | ||||
| @@ -24,6 +28,11 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("Must specify the optimizer name"); | throw new NotImplementedException("Must specify the optimizer name"); | ||||
| Name = name; | Name = name; | ||||
| _use_locking = use_locking; | |||||
| // Dictionary of slots. | |||||
| _slots = new Dictionary<string, object>(); | |||||
| _non_slot_dict = new Dictionary<string, object>(); | |||||
| _deferred_slot_restorations = new Dictionary<string, object>(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -68,7 +77,7 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| } | } | ||||
| var processors = var_list.Select(v => optimizer._get_processor(v)); | |||||
| var processors = var_list.Select(v => optimizer._get_processor(v)).ToList(); | |||||
| var var_refs = processors.Select(x => x.target()).ToList(); | var var_refs = processors.Select(x => x.target()).ToList(); | ||||
| gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss, | gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss, | ||||
| @@ -79,6 +79,17 @@ namespace Tensorflow | |||||
| // have an issue if these other variables aren't initialized first by | // have an issue if these other variables aren't initialized first by | ||||
| // using their initialized_value() method. | // using their initialized_value() method. | ||||
| var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; | |||||
| if (!String.IsNullOrEmpty(caching_device)) | |||||
| { | |||||
| } | |||||
| else | |||||
| { | |||||
| } | |||||
| ops.add_to_collections(collections, this); | ops.add_to_collections(collections, this); | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -33,5 +34,31 @@ namespace Tensorflow | |||||
| return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Update 'ref' by assigning 'value' to it | |||||
| /// </summary> | |||||
| /// <param name="REF"></param> | |||||
| /// <param name="value"></param> | |||||
| /// <param name="validate_shape"></param> | |||||
| /// <param name="use_locking"></param> | |||||
| /// <param name="name"></param> | |||||
| public static Tensor assign(Tensor tensor, Tensor value, | |||||
| bool validate_shape = true, | |||||
| bool use_locking = true, | |||||
| string name = "") | |||||
| { | |||||
| var keywords = new Dictionary<string, object>(); | |||||
| keywords.Add("ref", tensor); | |||||
| keywords.Add("value", value); | |||||
| keywords.Add("validate_shape", validate_shape); | |||||
| keywords.Add("use_locking", use_locking); | |||||
| var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | |||||
| var _result = _op.outputs[0]; | |||||
| var _inputs_flat = _op.inputs; | |||||
| return _result; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -21,8 +21,6 @@ namespace Tensorflow | |||||
| _default_name = default_name; | _default_name = default_name; | ||||
| _values = values; | _values = values; | ||||
| _ctx = new Context(); | _ctx = new Context(); | ||||
| _name_scope = __enter__(); | |||||
| } | } | ||||
| public string __enter__() | public string __enter__() | ||||
| @@ -38,8 +36,10 @@ namespace Tensorflow | |||||
| if (g == null) | if (g == null) | ||||
| g = get_default_graph(); | g = get_default_graph(); | ||||
| return g.name_scope(_name); ; | |||||
| _name_scope = g.name_scope(_name); | |||||
| return _name_scope; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -48,9 +48,13 @@ namespace Tensorflow | |||||
| g._name_stack = g.old_stack; | g._name_stack = g.old_stack; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// __enter__() | |||||
| /// </summary> | |||||
| /// <param name="ns"></param> | |||||
| public static implicit operator string(name_scope<T> ns) | public static implicit operator string(name_scope<T> ns) | ||||
| { | { | ||||
| return ns._name_scope; | |||||
| return ns.__enter__(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -7,7 +7,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="NumSharp" Version="0.6.5" /> | <PackageReference Include="NumSharp" Version="0.6.5" /> | ||||
| <PackageReference Include="TensorFlow.NET" Version="0.0.2" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.0.3" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -20,7 +20,7 @@ | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
| <PackageReference Include="NumSharp" Version="0.6.5" /> | <PackageReference Include="NumSharp" Version="0.6.5" /> | ||||
| <PackageReference Include="TensorFlow.NET" Version="0.0.2" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.0.3" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||