Update to Tensorflow 2.16.1, drop support for Bidirectional, GRU, and LSTM
Dobiasd authored Apr 16, 2024
1 parent 5637438 commit a60717c
Showing 29 changed files with 196 additions and 2,721 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/ci.yml
@@ -13,7 +13,9 @@ jobs:
sudo apt-get install libblas-dev liblapack-dev libatlas-base-dev gfortran
# python libs
sudo pip3 install --upgrade pip
sudo pip3 install numpy scipy h5py "tensorflow==2.15.0"
sudo pip3 install numpy scipy h5py "tensorflow==2.16.1"
echo "Python version:"
python3 --version
echo "Version numbers of TensorFlow and Keras:"
python3 -c "import tensorflow as tf; import tensorflow.keras; print(tf.__version__)"
# FunctionalPlus
@@ -54,12 +56,6 @@ jobs:
cmake .. -DFDEEP_BUILD_UNITTEST=ON
cmake --build . --target unittest --config Release --
cd ..
# run stateful tests
cd test/stateful_test
g++ -I../../include -std=c++14 -O3 stateful_recurrent_tests.cpp -o stateful_recurrent_tests_cpp
mkdir models
python3 stateful_recurrent_tests.py
cd ../..
formatting-check:
1 change: 0 additions & 1 deletion .gitignore
@@ -6,5 +6,4 @@ _*
.mypy_cache
/experiments
.idea
test/stateful_test/models
CMakeUserPresets.json
5 changes: 3 additions & 2 deletions README.md
@@ -41,7 +41,7 @@ Would you like to build/train a model using Keras/Python? And would you like to

* `Add`, `Concatenate`, `Subtract`, `Multiply`, `Average`, `Maximum`, `Minimum`, `Dot`
* `AveragePooling1D/2D/3D`, `GlobalAveragePooling1D/2D/3D`
* `Bidirectional`, `TimeDistributed`, `GRU`, `LSTM`, `CuDNNGRU`, `CuDNNLSTM`
* `TimeDistributed`
* `Conv1D/2D`, `SeparableConv2D`, `DepthwiseConv2D`
* `Cropping1D/2D/3D`, `ZeroPadding1D/2D/3D`, `CenterCrop`
* `BatchNormalization`, `Dense`, `Flatten`, `Normalization`
@@ -81,6 +81,7 @@ Would you like to build/train a model using Keras/Python? And would you like to
`LSTMCell`, `Masking`,
`RepeatVector`, `RNN`, `SimpleRNN`,
`SimpleRNNCell`, `StackedRNNCells`, `StringLookup`, `TextVectorization`,
`Bidirectional`, `GRU`, `LSTM`, `CuDNNGRU`, `CuDNNLSTM`,
`ThresholdedReLU`, `Upsampling3D`, `temporal` models

Usage
@@ -139,7 +140,7 @@ Requirements and Installation

- A **C++14**-compatible compiler: Compilers from these versions on are fine: GCC 4.9, Clang 3.7 (libc++ 3.7) and Visual C++ 2015
- Python 3.7 or higher
- TensorFlow 2.15.0 (These are the tested versions, but somewhat older ones might work too.)
- TensorFlow 2.16.1 (These are the tested versions, but somewhat older ones might work too.)

Guides for different ways to install frugally-deep can be found in [`INSTALL.md`](INSTALL.md).
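
For orientation (not part of this commit), the README's elided usage section boils down to exporting a trained Keras model to JSON with the repository's converter script and loading that JSON from C++. A minimal sketch along the lines of the project's README example; the model file name is arbitrary and the exact converter invocation is not shown in this diff:

```cpp
// Minimal usage sketch (assumption: mirrors the project's README example,
// which this excerpt does not show). Requires a model previously exported
// to JSON, e.g. with keras_export/convert_model.py.
#include <fdeep/fdeep.hpp>
#include <iostream>

int main()
{
    // Load the converted model from disk.
    const auto model = fdeep::load_model("fdeep_model.json");

    // Run a prediction on a single 4-element input tensor.
    const auto result = model.predict(
        { fdeep::tensor(fdeep::tensor_shape(static_cast<std::size_t>(4)),
                        std::vector<float>{ 1, 2, 3, 4 }) });

    std::cout << fdeep::show_tensors(result) << std::endl;
}
```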

175 changes: 25 additions & 150 deletions include/fdeep/import_model.hpp
@@ -33,7 +33,6 @@
#include "fdeep/layers/average_layer.hpp"
#include "fdeep/layers/average_pooling_3d_layer.hpp"
#include "fdeep/layers/batch_normalization_layer.hpp"
#include "fdeep/layers/bidirectional_layer.hpp"
#include "fdeep/layers/category_encoding_layer.hpp"
#include "fdeep/layers/centercrop_layer.hpp"
#include "fdeep/layers/concatenate_layer.hpp"
@@ -49,14 +48,12 @@
#include "fdeep/layers/gelu_layer.hpp"
#include "fdeep/layers/global_average_pooling_3d_layer.hpp"
#include "fdeep/layers/global_max_pooling_3d_layer.hpp"
#include "fdeep/layers/gru_layer.hpp"
#include "fdeep/layers/hard_sigmoid_layer.hpp"
#include "fdeep/layers/input_layer.hpp"
#include "fdeep/layers/layer.hpp"
#include "fdeep/layers/layer_normalization_layer.hpp"
#include "fdeep/layers/leaky_relu_layer.hpp"
#include "fdeep/layers/linear_layer.hpp"
#include "fdeep/layers/lstm_layer.hpp"
#include "fdeep/layers/max_pooling_3d_layer.hpp"
#include "fdeep/layers/maximum_layer.hpp"
#include "fdeep/layers/minimum_layer.hpp"
@@ -319,7 +316,7 @@ namespace internal {
return create_vector<tensor_shape_variable>(create_tensor_shape_variable, data);
}

inline node_connection create_node_connection(const nlohmann::json& data)
inline node_connection create_node_connection_model_layer(const nlohmann::json& data)
{
assertion(data.is_array(), "invalid format for inbound node");
const std::string layer_id = data.front();
@@ -328,6 +325,16 @@ namespace internal {
return node_connection(layer_id, node_idx, tensor_idx);
}

inline node_connection create_node_connection(const nlohmann::json& args)
{
const std::vector<nlohmann::json> keras_history = args["config"]["keras_history"];
assertion(keras_history.size() >= 3, "invalid number of items in keras_history");
const std::string layer_id = keras_history[0];
const auto node_idx = create_size_t(keras_history[1]);
const auto tensor_idx = create_size_t(keras_history[2]);
return node_connection(layer_id, node_idx, tensor_idx);
}
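
For context, the inbound-node argument shape this new `create_node_connection` assumes looks roughly like the sketch below. It is derived from the parsing code itself, not from an official Keras 3 serialization spec; the layer name and the standalone `main` are illustrative only, and nlohmann/json is assumed to be available:

```cpp
// Hedged sketch: each argument encodes its producing tensor under
// config.keras_history as [layer_id, node_idx, tensor_idx].
#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
    const auto arg = nlohmann::json::parse(R"(
        { "config": { "keras_history": ["dense_1", 0, 0] } }
    )");
    const std::vector<nlohmann::json> keras_history = arg["config"]["keras_history"];
    const std::string layer_id = keras_history[0];
    const std::size_t node_idx = keras_history[1];
    const std::size_t tensor_idx = keras_history[2];
    std::cout << layer_id << " " << node_idx << " " << tensor_idx << "\n"; // dense_1 0 0
}
```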

using get_param_f = std::function<nlohmann::json(const std::string&, const std::string&)>;

using layer_creators = std::map<
@@ -375,10 +382,10 @@ namespace internal {
assertion(data["config"]["input_layers"].is_array(), "no input layers");

const auto inputs = create_vector<node_connection>(
create_node_connection, data["config"]["input_layers"]);
create_node_connection_model_layer, data["config"]["input_layers"]);

const auto outputs = create_vector<node_connection>(
create_node_connection, data["config"]["output_layers"]);
create_node_connection_model_layer, data["config"]["output_layers"]);

return std::make_shared<model_layer>(name, layers, inputs, outputs);
}
@@ -497,7 +504,7 @@ namespace internal {
{
assertion(data["inbound_nodes"].empty(),
"input layer is not allowed to have inbound nodes");
const auto input_shape = create_tensor_shape_variable_leading_null(data["config"]["batch_input_shape"]);
const auto input_shape = create_tensor_shape_variable_leading_null(data["config"]["batch_shape"]);
return std::make_shared<input_layer>(name, input_shape);
}

@@ -931,7 +938,7 @@ namespace internal {
const get_param_f&, const nlohmann::json& data,
const std::string& name)
{
float_type alpha = 1.0f;
float_type alpha = 0.3f;
if (json_obj_has_member(data, "config") && json_obj_has_member(data["config"], "alpha")) {
alpha = data["config"]["alpha"];
}
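
For readers unfamiliar with the layer, the fallback of 0.3 matches Keras' documented default slope for LeakyReLU. A toy sketch of the function itself, illustrative only and not the library's own implementation:

```cpp
// Toy LeakyReLU with the 0.3 fallback slope used above; illustrative only.
#include <iostream>

double leaky_relu(double x, double alpha = 0.3)
{
    return x >= 0.0 ? x : alpha * x;
}

int main()
{
    std::cout << leaky_relu(2.0) << " " << leaky_relu(-2.0) << "\n"; // 2 -0.6
}
```
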
@@ -1085,6 +1092,7 @@ namespace internal {
{ "tanh", create_tanh_layer },
{ "sigmoid", create_sigmoid_layer },
{ "swish", create_swish_layer },
{ "silu", create_swish_layer },
{ "hard_sigmoid", create_hard_sigmoid_layer },
{ "relu", create_relu_layer },
{ "relu6", create_relu6_layer },
@@ -1121,37 +1129,21 @@

inline node create_node(const nlohmann::json& inbound_nodes_data)
{
assertion(inbound_nodes_data.is_array(), "nodes need to be an array");
return node(create_vector<node_connection>(create_node_connection,
inbound_nodes_data));
}

inline nodes create_multi_head_attention_nodes(const std::vector<nlohmann::json> inbound_nodes_data)
{
assertion(inbound_nodes_data.size() == 1 && inbound_nodes_data.front().size() == 1,
"multi_head_attention needs to have exactly one primary inbound node; see https://stackoverflow.com/q/77400589/1866775");
const auto inbound_node_data = inbound_nodes_data.front().front();
const auto value = inbound_node_data[3]["value"];
if (json_obj_has_member(inbound_node_data[3], "key")) {
return {
node({ create_node_connection(inbound_node_data),
create_node_connection(value),
create_node_connection(inbound_node_data[3]["key"]) })
};
assertion(inbound_nodes_data["args"].is_array(), "node args need to be an array");
std::vector<nlohmann::json> args = inbound_nodes_data["args"];
if (args.front().is_array()) {
assertion(args.size() == 1, "invalid args format");
const std::vector<nlohmann::json> inner_args = args.front();
return node(fplus::transform(create_node_connection, inner_args));
} else {
return node(fplus::transform(create_node_connection, args));
}
return {
node({ create_node_connection(inbound_node_data),
create_node_connection(value) })
};
}

inline nodes create_nodes(const nlohmann::json& data)
{
assertion(data["inbound_nodes"].is_array(), "no inbound nodes");
const std::vector<nlohmann::json> inbound_nodes_data = data["inbound_nodes"];
if (data["class_name"] == "MultiHeadAttention") {
return create_multi_head_attention_nodes(inbound_nodes_data);
}
return fplus::transform(create_node, inbound_nodes_data);
}
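
To make the branch in `create_node` concrete: it accepts either one tensor reference per positional argument, or a single nested list of tensor references (as merge-style layers serialize their inputs). A hedged sketch of both shapes, derived from the branching logic above rather than from a Keras specification; the layer names are made up and nlohmann/json is assumed:

```cpp
// Hedged sketch of the two "args" layouts create_node distinguishes.
#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
    // Shape 1: one tensor reference per positional argument.
    const auto flat = nlohmann::json::parse(R"(
        { "args": [ { "config": { "keras_history": ["dense_1", 0, 0] } } ] }
    )");
    // Shape 2: a single list of tensor references (e.g. a Concatenate input).
    const auto nested = nlohmann::json::parse(R"(
        { "args": [ [ { "config": { "keras_history": ["conv_a", 0, 0] } },
                      { "config": { "keras_history": ["conv_b", 0, 0] } } ] ] }
    )");
    const std::vector<nlohmann::json> flat_args = flat["args"];
    const std::vector<nlohmann::json> nested_args = nested["args"];
    std::cout << std::boolalpha
              << flat_args.front().is_array() << "\n"    // false -> use args directly
              << nested_args.front().is_array() << "\n"; // true  -> unwrap the inner list
}
```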

@@ -1166,115 +1158,6 @@
return std::make_shared<embedding_layer>(name, input_dim, output_dim, weights);
}

inline layer_ptr create_lstm_layer(const get_param_f& get_param,
const nlohmann::json& data,
const std::string& name)
{
auto&& config = data["config"];
const std::size_t units = config["units"];
const std::string unit_activation = json_object_get_activation_with_default(config, "tanh");
const std::string recurrent_activation = json_object_get(config,
"recurrent_activation",
data["class_name"] == "CuDNNLSTM"
? std::string("sigmoid")
: std::string("hard_sigmoid"));
const bool use_bias = json_object_get(config, "use_bias", true);

float_vec bias;
if (use_bias)
bias = decode_floats(get_param(name, "bias"));

const float_vec weights = decode_floats(get_param(name, "weights"));
const float_vec recurrent_weights = decode_floats(get_param(name, "recurrent_weights"));
const bool return_sequences = json_object_get(config, "return_sequences", false);
const bool return_state = json_object_get(config, "return_state", false);
const bool stateful = json_object_get(config, "stateful", false);

return std::make_shared<lstm_layer>(name, units, unit_activation,
recurrent_activation, use_bias,
return_sequences, return_state, stateful,
weights, recurrent_weights, bias);
}

inline layer_ptr create_gru_layer(const get_param_f& get_param,
const nlohmann::json& data,
const std::string& name)
{
auto&& config = data["config"];
const std::size_t units = config["units"];
const std::string unit_activation = json_object_get_activation_with_default(config, "tanh");
const std::string recurrent_activation = json_object_get(config,
"recurrent_activation",
data["class_name"] == "CuDNNGRU"
? std::string("sigmoid")
: std::string("hard_sigmoid"));

const bool use_bias = json_object_get(config, "use_bias", true);
const bool return_sequences = json_object_get(config, "return_sequences", false);
const bool return_state = json_object_get(config, "return_state", false);
const bool stateful = json_object_get(config, "stateful", false);

float_vec bias;
if (use_bias)
bias = decode_floats(get_param(name, "bias"));

const float_vec weights = decode_floats(get_param(name, "weights"));
const float_vec recurrent_weights = decode_floats(get_param(name, "recurrent_weights"));

bool reset_after = json_object_get(config,
"reset_after",
data["class_name"] == "CuDNNGRU");

return std::make_shared<gru_layer>(name, units, unit_activation,
recurrent_activation, use_bias, reset_after,
return_sequences, return_state, stateful,
weights, recurrent_weights, bias);
}

inline layer_ptr create_bidirectional_layer(const get_param_f& get_param,
const nlohmann::json& data,
const std::string& name)
{
const std::string merge_mode = data["config"]["merge_mode"];
auto&& layer = data["config"]["layer"];
auto&& layer_config = layer["config"];
const std::string wrapped_layer_type = layer["class_name"];
const std::size_t units = layer_config["units"];
const std::string unit_activation = json_object_get_activation_with_default(layer_config, "tanh");
const std::string recurrent_activation = json_object_get(layer_config,
"recurrent_activation",
wrapped_layer_type == "CuDNNGRU" || wrapped_layer_type == "CuDNNLSTM"
? std::string("sigmoid")
: std::string("hard_sigmoid"));
const bool use_bias = json_object_get(layer_config, "use_bias", true);

float_vec forward_bias;
float_vec backward_bias;

if (use_bias) {
forward_bias = decode_floats(get_param(name, "forward_bias"));
backward_bias = decode_floats(get_param(name, "backward_bias"));
}

const float_vec forward_weights = decode_floats(get_param(name, "forward_weights"));
const float_vec backward_weights = decode_floats(get_param(name, "backward_weights"));

const float_vec forward_recurrent_weights = decode_floats(get_param(name, "forward_recurrent_weights"));
const float_vec backward_recurrent_weights = decode_floats(get_param(name, "backward_recurrent_weights"));

const bool reset_after = json_object_get(layer_config,
"reset_after",
wrapped_layer_type == "CuDNNGRU");
const bool return_sequences = json_object_get(layer_config, "return_sequences", false);
const bool stateful = json_object_get(layer_config, "stateful", false);

return std::make_shared<bidirectional_layer>(name, merge_mode, units, unit_activation,
recurrent_activation, wrapped_layer_type,
use_bias, reset_after, return_sequences, stateful,
forward_weights, forward_recurrent_weights, forward_bias,
backward_weights, backward_recurrent_weights, backward_bias);
}

inline layer_ptr create_time_distributed_layer(const get_param_f& get_param,
const nlohmann::json& data,
const std::string& name,
@@ -1368,11 +1251,6 @@ namespace internal {
{ "Reshape", create_reshape_layer },
{ "Resizing", create_resizing_layer },
{ "Embedding", create_embedding_layer },
{ "LSTM", create_lstm_layer },
{ "CuDNNLSTM", create_lstm_layer },
{ "GRU", create_gru_layer },
{ "CuDNNGRU", create_gru_layer },
{ "Bidirectional", create_bidirectional_layer },
{ "Softmax", create_softmax_layer },
{ "Normalization", create_normalization_layer },
{ "CategoryEncoding", create_category_encoding_layer },
@@ -1403,10 +1281,7 @@ namespace internal {
fplus::get_from_map(creators, type))(
get_param, data, name);

if (type != "Activation" && json_obj_has_member(data["config"], "activation")
&& type != "GRU"
&& type != "LSTM"
&& type != "Bidirectional") {
if (type != "Activation" && json_obj_has_member(data["config"], "activation")) {
const std::string activation = get_activation_type(data["config"]["activation"]);
result->set_activation(
create_activation_layer_type_name(get_param, data,