Browse Source

use SetOpAttrs instead of SetOpAttrWithDefaults

tags/v0.20
Oceania2018 5 years ago
parent
commit
10ebec48da
5 changed files with 51 additions and 37 deletions
  1. +12
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  4. +19
    -20
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  5. +16
    -13
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs

+ 12
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -333,6 +333,18 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx);


/// <summary>
///
/// </summary>
/// <param name="ctx"></param>
/// <param name="device_name"></param>
/// <param name="op_name"></param>
/// <param name="name"></param>
/// <param name="args"></param>
/// <param name="input_size"></param>
/// <param name="set_op_attrs"></param>
/// <param name="status"></param>
/// <returns>EagerTensorHandle</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_FastPathExecute(IntPtr ctx, public static extern IntPtr TFE_FastPathExecute(IntPtr ctx,
string device_name, string device_name,


+ 3
- 3
src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs View File

@@ -173,7 +173,7 @@ namespace Tensorflow.Eager
return true; return true;
} }


private static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, int start_index, Status out_status)
public static void SetOpAttrs(Context ctx, TFE_Op op, object[] attrs, int start_index, Status out_status)
{ {
var len = attrs.Length; var len = attrs.Length;
for (int i = 0; i < len; i += 2) for (int i = 0; i < len; i += 2)
@@ -181,7 +181,7 @@ namespace Tensorflow.Eager
var key = attrs[start_index + i].ToString(); var key = attrs[start_index + i].ToString();
var value = attrs[start_index + i + 1]; var value = attrs[start_index + i + 1];


byte is_list = 0;
byte is_list = 0;
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, out_status); var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, out_status);
if (!out_status.ok()) return; if (!out_status.ok()) return;
if (is_list != 0) if (is_list != 0)
@@ -205,7 +205,7 @@ namespace Tensorflow.Eager
/// <param name="attr_value"></param> /// <param name="attr_value"></param>
/// <param name="attr_list_sizes"></param> /// <param name="attr_list_sizes"></param>
/// <param name="status"></param> /// <param name="status"></param>
public static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr,
private static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr,
string attr_name, object attr_value, string attr_name, object attr_value,
Dictionary<string, long> attr_list_sizes, Dictionary<string, long> attr_list_sizes,
Status status) Status status)


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -156,7 +156,7 @@ namespace Tensorflow
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Pack", name, "Pack", name,
values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 1, values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 1,
(op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "axis", axis, null, status),
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "axis", axis }, 0 , status),
status); status);
status.Check(true); status.Check(true);
return new EagerTensor(tensor); return new EagerTensor(tensor);


+ 19
- 20
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -138,16 +138,19 @@ namespace Tensorflow
"Mean", name, "Mean", name,
new IntPtr[] new IntPtr[]
{ {
(input as EagerTensor).EagerTensorHandle,
(axis as EagerTensor).EagerTensorHandle
input as EagerTensor,
axis as EagerTensor
}, 2, }, 2,
(op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "keep_dims", keep_dims, null, status),
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "keep_dims", keep_dims }, 0, status),
status); status);
status.Check(true); status.Check(true);
return new EagerTensor(tensor); return new EagerTensor(tensor);
} }
catch (Exception) catch (Exception)
{ {
/*tensors = c_api.TFE_Execute(tf.context, tf.context.device_name, op_name,
inputs, attrs, num_outputs);*/

return mean_eager_fallback(input as Tensor[], axis as Tensor, keep_dims: keep_dims, name: name, ctx: tf.context); return mean_eager_fallback(input as Tensor[], axis as Tensor, keep_dims: keep_dims, name: name, ctx: tf.context);
} }
} }
@@ -220,8 +223,8 @@ namespace Tensorflow
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Add", name, new IntPtr[] "Add", name, new IntPtr[]
{ {
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
x as EagerTensor,
y as EagerTensor
}, 2, null, status); }, 2, null, status);
status.Check(true); status.Check(true);
return new EagerTensor(_result); return new EagerTensor(_result);
@@ -595,12 +598,8 @@ namespace Tensorflow
using var status = new Status(); using var status = new Status();
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Cast", name, "Cast", name,
new IntPtr[] { (x as EagerTensor).EagerTensorHandle }, 1,
(op) =>
{
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "DstT", DstT, null, status);
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "Truncate", Truncate, null, status);
},
new IntPtr[] { x as EagerTensor }, 1,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "DstT", DstT, "Truncate", Truncate }, 0, status),
status); status);
status.Check(true); status.Check(true);
return new EagerTensor(tensor); return new EagerTensor(tensor);
@@ -649,8 +648,8 @@ namespace Tensorflow
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Sub", name, new IntPtr[] "Sub", name, new IntPtr[]
{ {
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
x as EagerTensor,
y as EagerTensor
}, 2, null, status); }, 2, null, status);
status.Check(true); status.Check(true);
return new EagerTensor(_result); return new EagerTensor(_result);
@@ -743,8 +742,8 @@ namespace Tensorflow
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Mul", name, new IntPtr[] "Mul", name, new IntPtr[]
{ {
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
x as EagerTensor,
y as EagerTensor
}, 2, null, status); }, 2, null, status);
status.Check(true); status.Check(true);
return new EagerTensor(_result); return new EagerTensor(_result);
@@ -785,8 +784,8 @@ namespace Tensorflow
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"RealDiv", name, new IntPtr[] "RealDiv", name, new IntPtr[]
{ {
(x as EagerTensor).EagerTensorHandle,
(y as EagerTensor).EagerTensorHandle
x as EagerTensor,
y as EagerTensor
}, 2, null, status); }, 2, null, status);
status.Check(true); status.Check(true);
return new EagerTensor(_result); return new EagerTensor(_result);
@@ -988,9 +987,9 @@ namespace Tensorflow
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Range", name, new IntPtr[] "Range", name, new IntPtr[]
{ {
(start as EagerTensor).EagerTensorHandle,
(limit as EagerTensor).EagerTensorHandle,
(delta as EagerTensor).EagerTensorHandle
start as EagerTensor,
limit as EagerTensor,
delta as EagerTensor
}, 3, null, status); }, 3, null, status);
status.Check(true); status.Check(true);
return new EagerTensor(tensor); return new EagerTensor(tensor);


+ 16
- 13
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs View File

@@ -32,10 +32,10 @@ namespace Tensorflow
using var status = new Status(); using var status = new Status();
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"AssignVariableOp", name, "AssignVariableOp", name,
new[]
new IntPtr[]
{ {
(resource as EagerTensor).EagerTensorHandle,
(value as EagerTensor).EagerTensorHandle
resource as EagerTensor,
value as EagerTensor
}, 2, null, status); }, 2, null, status);
status.Check(true); status.Check(true);
return tensor; return tensor;
@@ -53,7 +53,7 @@ namespace Tensorflow
using var status = new Status(); using var status = new Status();
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"VarIsInitializedOp", name, "VarIsInitializedOp", name,
new[] { (resource as EagerTensor).EagerTensorHandle },
new IntPtr[] { resource as EagerTensor },
1, null, status); 1, null, status);
status.Check(true); status.Check(true);
return new EagerTensor(tensor); return new EagerTensor(tensor);
@@ -80,13 +80,15 @@ namespace Tensorflow
{ {
using var status = new Status(); using var status = new Status();
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"VarHandleOp", name, null, 0, op =>
{
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "container", container, null, status);
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shared_name", shared_name, null, status);
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status);
wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shape", shape.dims, null, status);
}, status);
"VarHandleOp", name, null, 0,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[]
{
"container", container,
"shared_name", shared_name,
"dtype", dtype,
"shape", shape.dims
}, 0, status),
status);
status.Check(true); status.Check(true);
return new EagerTensor(tensor); return new EagerTensor(tensor);
} }
@@ -114,8 +116,9 @@ namespace Tensorflow
{ {
using var status = new Status(); using var status = new Status();
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"ReadVariableOp", name, new IntPtr[] { (resource as EagerTensor).EagerTensorHandle }, 1,
(op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status),
"ReadVariableOp", name,
new IntPtr[] { resource as EagerTensor }, 1,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "dtype", dtype }, 0, status),
status); status);
status.Check(true); status.Check(true);
return new EagerTensor(tensor); return new EagerTensor(tensor);


Loading…
Cancel
Save