Browse Source

Merge pull request #5 from AsakusaRinne/add_cv_compatibility

Add a tolorance to equivalence of NDArray.
pull/1047/head
Rinne GitHub 2 years ago
parent
commit
021d37cc18
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 3 deletions
  1. +19
    -2
      src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs
  2. +17
    -0
      src/TensorflowNET.Hub/Tensorflow.Hub.csproj
  3. +1
    -1
      test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj

+ 19
- 2
src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs View File

@@ -33,7 +33,16 @@ namespace Tensorflow.NumPy
return Scalar(false);
if(rhs is null)
return Scalar(false);
return new NDArray(math_ops.equal(lhs, rhs));
// TODO(Rinne): use np.allclose instead.
if (lhs.dtype.is_floating() || rhs.dtype.is_floating())
{
var diff = tf.abs(lhs - rhs);
return new NDArray(gen_math_ops.less(diff, new NDArray(1e-5).astype(diff.dtype)));
}
else
{
return new NDArray(math_ops.equal(lhs, rhs));
}
}
[AutoNumPy]
public static NDArray operator !=(NDArray lhs, NDArray rhs)
@@ -42,7 +51,15 @@ namespace Tensorflow.NumPy
return Scalar(false);
if(lhs is null || rhs is null)
return Scalar(true);
return new NDArray(math_ops.not_equal(lhs, rhs));
if (lhs.dtype.is_floating() || rhs.dtype.is_floating())
{
var diff = tf.abs(lhs - rhs);
return new NDArray(gen_math_ops.greater_equal(diff, new NDArray(1e-5).astype(diff.dtype)));
}
else
{
return new NDArray(math_ops.not_equal(lhs, rhs));
}
}
}
}

+ 17
- 0
src/TensorflowNET.Hub/Tensorflow.Hub.csproj View File

@@ -5,6 +5,23 @@
<LangVersion>10</LangVersion>
<Nullable>enable</Nullable>
<Version>1.0.0</Version>
<PackageId>TensorFlow.NET.Hub</PackageId>
<PackageLicenseExpression>Apache2.0</PackageLicenseExpression>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>
<Authors>Yaohui Liu, Haiping Chen</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<Copyright>Apache 2.0, Haiping Chen $([System.DateTime]::UtcNow.ToString(yyyy))</Copyright>
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
<RepositoryType>git</RepositoryType>
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl>
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
<PackageTags>TensorFlow, SciSharp, Machine Learning, Deep Learning, TensorFlow Hub, TensorFlow.NET, TF.NET, AI</PackageTags>
<Description>
Google's TensorFlow Hub full binding in .NET Standard.
A library for transfer learning with TensorFlow.NET.
</Description>
</PropertyGroup>

<ItemGroup>


+ 1
- 1
test/TensorflowNET.Hub.Unittest/Tensorflow.Hub.Unittest.csproj View File

@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6</TargetFramework>


Loading…
Cancel
Save