Skip to content

Commit d4f1c34

Browse files
committed
fix _apply_dense for Optimizer.
1 parent d3b681e commit d4f1c34

31 files changed

+299
-108
lines changed

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ public Tensor dense(Tensor inputs,
193193
Name = name
194194
});
195195

196-
throw new NotImplementedException("");
197-
//return layer.apply(inputs).Item1;
196+
return layer.Apply(inputs);
198197
}
199198

200199
/// <summary>

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ public Tensor dropout(Tensor x, Tensor keep_prob = null, Tensor noise_shape = nu
6666
Tensor keep = null;
6767
if (keep_prob != null)
6868
keep = 1.0f - keep_prob;
69-
70-
return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name);
69+
var rate_tensor = rate.HasValue ? tf.constant(rate.Value) : keep;
70+
return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name);
7171
}
7272

7373
/// <summary>

src/TensorFlowNET.Core/Framework/meta_graph.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public static (Dictionary<string, IVariableV1>, ITensorOrOperation[]) import_sco
150150
var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
151151
scope: scope_to_prepend_to_names);
152152
var var_list = new Dictionary<string, IVariableV1>();
153-
// variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
153+
variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);
154154

155155
return (var_list, imported_return_elements);
156156
}
@@ -277,6 +277,11 @@ private static void add_collection_def(MetaGraphDef meta_graph_def,
277277
var proto = x_ref_var.to_proto(export_scope);
278278
col_def.BytesList.Value.Add(proto.ToByteString());
279279
}
280+
else if(x is ResourceVariable x_res_var)
281+
{
282+
var proto = x_res_var.to_proto(export_scope);
283+
col_def.BytesList.Value.Add(proto.ToByteString());
284+
}
280285
}
281286
break;
282287
case List<RefVariable> collection_list:

src/TensorFlowNET.Core/Functions/c_api.function.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,23 @@ public partial class c_api
3131
/// <param name="output_func_def"></param>
3232
/// <param name="status"></param>
3333
[DllImport(TensorFlowLibName)]
34-
public static extern void TF_FunctionToFunctionDef(IntPtr func, IntPtr output_func_def, SafeStatusHandle status);
34+
public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status);
3535

36+
[DllImport(TensorFlowLibName)]
37+
public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
38+
bool append_hash_to_fn_name,
39+
int num_opers, IntPtr[] opers,
40+
int ninputs, TF_Output[] inputs,
41+
int noutputs, TF_Output[] outputs,
42+
IntPtr output_names,
43+
IntPtr opts,
44+
string description,
45+
SafeStatusHandle status);
46+
47+
[DllImport(TensorFlowLibName)]
48+
public static extern IntPtr TF_FunctionName(IntPtr func);
3649

50+
[DllImport(TensorFlowLibName)]
51+
public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status);
3752
}
3853
}

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,9 @@ public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
327327
var output_shape = op.outputs[0]._shape_tuple();
328328

329329
Tensor result, factor_tensor;
330-
if(input_shape != null &&
331-
output_shape != null)
330+
if(tf.executing_eagerly()
331+
&& input_shape != null
332+
&& output_shape != null)
332333
{
333334
var input_size = np.prod(input_shape);
334335
var output_size = np.prod(output_shape);
@@ -339,11 +340,7 @@ public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
339340
{
340341
var input_shape_tensor = array_ops.shape(op.inputs[0]);
341342
var output_shape_tensor = array_ops.shape(op.outputs[0]);
342-
var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
343-
throw new NotImplementedException("");
344-
#pragma warning disable CS0162 // Unreachable code detected
345-
factor_tensor = null;
346-
#pragma warning restore CS0162 // Unreachable code detected
343+
factor_tensor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor));
347344
}
348345

349346
result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype));

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Te
128128
[RegisterGradient("Conv2D")]
129129
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
130130
{
131-
var dilations = op.get_attr<int[]>("dilations");
132-
var strides = op.get_attr<int[]>("strides");
131+
var dilations = op.get_attr_list<int>("dilations");
132+
var strides = op.get_attr_list<int>("strides");
133133
var padding = op.get_attr<string>("padding");
134-
var explicit_paddings = op.get_attr<int[]>("explicit_paddings");
134+
var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
135135
var use_cudnn_on_gpu = op.get_attr<bool>("use_cudnn_on_gpu");
136136
var data_format = op.get_attr<string>("data_format");
137137
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
@@ -287,8 +287,8 @@ public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads)
287287
op.inputs[0],
288288
op.outputs[0],
289289
grad,
290-
op.get_attr("ksize") as int[],
291-
op.get_attr("strides") as int[],
290+
op.get_attr_list<int>("ksize"),
291+
op.get_attr_list<int>("strides"),
292292
padding: op.get_attr("padding").ToString(),
293293
data_format: op.get_attr("data_format").ToString())
294294
};

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,12 +293,6 @@ public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes
293293

294294
_create_op_helper(op, compute_device);
295295

296-
/*Console.Write($"create_op: {op_type} '{node_def.Name}'");
297-
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
298-
Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
299-
Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
300-
Console.WriteLine();*/
301-
302296
return op;
303297
}
304298

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ public partial class c_api
139139
/// <param name="status">TF_Status*</param>
140140
[DllImport(TensorFlowLibName)]
141141
public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);
142-
142+
143143
/// <summary>
144144
/// Returns the number of dimensions of the Tensor referenced by `output`
145145
/// in `graph`.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class TensorLikeDataAdapterArgs
8+
{
9+
public Tensor X { get; set; }
10+
public Tensor Y { get; set; }
11+
public int BatchSize { get; set; }
12+
public int Steps { get; set; }
13+
public int Epochs { get; set; }
14+
public bool Shuffle { get; set; }
15+
}
16+
}

src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ public class DataHandler
2727

2828
public DataHandler(DataHandlerArgs args)
2929
{
30+
this.args = args;
3031

32+
var adapter_cls = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { });
3133
}
3234
}
3335
}

0 commit comments

Comments
 (0)