Skip to content

Commit 93dda17

Browse files
committed
fix: Added support for training a multi-input model using a dataset.
1 parent f8b7bde commit 93dda17

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+13-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,19 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
112112
Steps = data_handler.Inferredsteps
113113
});
114114

115-
return evaluate(data_handler, callbacks, is_val, test_function);
115+
Func<DataHandler, OwnedIterator, Dictionary<string, float>> testFunction;
116+
117+
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
118+
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
119+
{
120+
testFunction = test_step_multi_inputs_function;
121+
}
122+
else
123+
{
124+
testFunction = test_function;
125+
}
126+
127+
return evaluate(data_handler, callbacks, is_val, testFunction);
116128
}
117129

118130
/// <summary>

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+12-1
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,20 @@ public ICallback fit(IDatasetV2 dataset,
179179
StepsPerExecution = _steps_per_execution
180180
});
181181

182+
Func<DataHandler, OwnedIterator, Dictionary<string, float>> trainStepFunction;
183+
184+
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
185+
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
186+
{
187+
trainStepFunction = train_step_multi_inputs_function;
188+
}
189+
else
190+
{
191+
trainStepFunction = train_step_function;
192+
}
182193

183194
return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
184-
train_step_func: train_step_function);
195+
train_step_func: trainStepFunction);
185196
}
186197

187198
History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,

0 commit comments

Comments
 (0)