Browse Source

Merge pull request #1140 from dogvane/master

fix same bug
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
9cd86812e5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 53 additions and 6 deletions
  1. +17
    -0
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  4. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  5. +6
    -2
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  6. +1
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  7. +1
    -0
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
  8. +25
    -0
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs

+ 17
- 0
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -365,6 +365,23 @@ namespace Tensorflow.Gradients
};
}

/// <summary>
/// Gradient function registered for the "AvgPool" op.
/// </summary>
/// <param name="op">
/// The forward operation. Its first input and its "ksize", "strides",
/// "padding" and "data_format" attributes are read.
/// </param>
/// <param name="grads">Incoming gradients; only the first element is used.</param>
/// <returns>
/// A single-element array holding the result of <c>gen_nn_ops.avg_pool_grad</c>.
/// </returns>
[RegisterGradient("AvgPool")]
public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads)
{
    // Read the upstream gradient first so an empty `grads` fails before any op is built.
    var upstream = grads[0];

    // Only the shape of the forward input is handed to the gradient kernel.
    var inputShape = array_ops.shape(op.inputs[0]);
    var kernelSize = op.get_attr_list<int>("ksize");
    var strides = op.get_attr_list<int>("strides");
    var padding = op.get_attr("padding").ToString();
    var dataFormat = op.get_attr("data_format").ToString();

    Tensor inputGrad = gen_nn_ops.avg_pool_grad(
        inputShape,
        upstream,
        kernelSize,
        strides,
        padding,
        dataFormat);

    return new Tensor[] { inputGrad };
}

/// <summary>
/// Return the gradients for TopK.
/// </summary>


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow.Keras
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }
Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null);
Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null);
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -145,7 +145,7 @@ namespace Tensorflow
throw new NotImplementedException("_zero_state_tensors");
}

public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null)
/// <summary>
/// Not implemented for this cell type; every call throws.
/// </summary>
/// <param name="inputs">Ignored.</param>
/// <param name="state">Ignored.</param>
/// <param name="is_training">Ignored.</param>
/// <param name="optional_args">Ignored.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null)
    => throw new NotImplementedException();


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param>
/// <param name="training"></param>
/// <returns></returns>
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null)
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null)
{
if (callContext.Value == null)
callContext.Value = new CallContext();


+ 6
- 2
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -142,6 +142,7 @@ namespace Tensorflow.Keras.Engine
int verbose = 1,
List<ICallback> callbacks = null,
IDatasetV2 validation_data = null,
int validation_step = 10, // 间隔多少次会进行一次验证
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -164,11 +165,11 @@ namespace Tensorflow.Keras.Engine
});


return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
@@ -207,6 +208,9 @@ namespace Tensorflow.Keras.Engine

if (validation_data != null)
{
if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0)
continue;

var val_logs = evaluate(validation_data);
foreach(var log in val_logs)
{


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -393,7 +393,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null)
{
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
if (optional_args is not null && rnn_optional_args is null)


+ 1
- 0
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -58,6 +58,7 @@ namespace Tensorflow.Keras
if (shuffle)
dataset = dataset.shuffle(batch_size * 8, seed: seed);
dataset = dataset.batch(batch_size);
dataset.class_names = class_name_list;
return dataset;
}



+ 25
- 0
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs View File

@@ -6,6 +6,31 @@ namespace Tensorflow.Keras
{
public partial class Preprocessing
{

/// <summary>
/// Converts a list of image file paths into a dataset of images for data processing.
/// </summary>
/// <param name="image_paths">Paths of the image files; each one is mapped through <c>path_to_image</c>.</param>
/// <param name="image_size">
/// Size forwarded to <c>path_to_image</c> for every image
/// (presumably the target height/width - confirm against <c>path_to_image</c>).
/// </param>
/// <param name="num_channels">Channel count forwarded to <c>path_to_image</c>.</param>
/// <param name="num_classes">
/// Not used by this method: the returned dataset carries no labels.
/// Kept only so existing call sites continue to compile.
/// </param>
/// <param name="interpolation">
/// Interpolation method used for resizing. Supports `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`.
/// Defaults to `'bilinear'`.
/// </param>
/// <returns>The dataset produced by mapping every path through <c>path_to_image</c>.</returns>
public IDatasetV2 paths_to_dataset(string[] image_paths,
    Shape image_size,
    int num_channels = 3,
    int num_classes = 6,
    string interpolation = "bilinear")
{
    var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);

    // The previous version also built a label dataset from a dummy
    // `new int[num_classes]` and then discarded it. That was dead work and
    // could throw for an invalid `num_classes`, so it has been removed.
    return path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation));
}

public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths,
Shape image_size,
int num_channels,


Loading…
Cancel
Save