| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
| { | { | ||||
| @@ -13,7 +14,7 @@ namespace Tensorflow.Gradients | |||||
| var input_op = op.inputs[0].op; | var input_op = op.inputs[0].op; | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| var op_ctxt = control_flow_util.GetOutputContext(input_op); | var op_ctxt = control_flow_util.GetOutputContext(input_op); | ||||
| var pred = op_ctxt.pred; | |||||
| var pred = (op_ctxt as CondContext).pred; | |||||
| var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | var results = control_flow_ops._SwitchRefOrTensor(grad, pred, name: "cond_grad"); | ||||
| return new Tensor[] { results.Item1, results.Item2 }; | return new Tensor[] { results.Item1, results.Item2 }; | ||||
| @@ -7,7 +7,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Operation | public partial class Operation | ||||
| { | { | ||||
| private CondContext _control_flow_context; | |||||
| private IControlFlowContext _control_flow_context; | |||||
| /// <summary> | /// <summary> | ||||
| /// Add this op to its control flow context. | /// Add this op to its control flow context. | ||||
| @@ -28,12 +28,12 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public void _set_control_flow_context(CondContext ctx) | |||||
| public void _set_control_flow_context(IControlFlowContext ctx) | |||||
| { | { | ||||
| _control_flow_context = ctx; | _control_flow_context = ctx; | ||||
| } | } | ||||
| public CondContext _get_control_flow_context() | |||||
| public IControlFlowContext _get_control_flow_context() | |||||
| { | { | ||||
| return _control_flow_context; | return _control_flow_context; | ||||
| } | } | ||||
| @@ -102,8 +102,12 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | |||||
| // This will be set by self.inputs. | // This will be set by self.inputs. | ||||
| if(op_def == null) | |||||
| if (op_def == null) | |||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
| @@ -27,7 +27,7 @@ namespace Tensorflow | |||||
| return op.type == "Switch" || op.type == "RefSwitch"; | return op.type == "Switch" || op.type == "RefSwitch"; | ||||
| } | } | ||||
| public static CondContext GetOutputContext(Operation op) | |||||
| public static IControlFlowContext GetOutputContext(Operation op) | |||||
| { | { | ||||
| var ctxt = op._get_control_flow_context(); | var ctxt = op._get_control_flow_context(); | ||||
| @@ -4,7 +4,7 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <Version>0.4.2</Version> | |||||
| <Version>0.5.0</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -13,14 +13,13 @@ | |||||
| <RepositoryType>git</RepositoryType> | <RepositoryType>git</RepositoryType> | ||||
| <PackageProjectUrl>https://github.com/SciSharp</PackageProjectUrl> | <PackageProjectUrl>https://github.com/SciSharp</PackageProjectUrl> | ||||
| <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | ||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | |||||
| <Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.4.2.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Added ConfigProto to control CPU and GPU resource. | |||||
| Fixed import name scope issue.</PackageReleaseNotes> | |||||
| <AssemblyVersion>0.5.0.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Add a lot of APIs to build neural networks model</PackageReleaseNotes> | |||||
| <LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
| <FileVersion>0.4.2.0</FileVersion> | |||||
| <FileVersion>0.5.0.0</FileVersion> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||