
Commit 7e850c6

Beacontownfc authored and Oceania2018 committed
Add bert example
1 parent 2aeef9e commit 7e850c6

File tree

5 files changed: +963 additions, −0 deletions
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Operations.Initializers;
using static Tensorflow.KerasApi;
using BERT;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.Keras.Engine.InputSpec;

namespace TensorFlowNET.Examples
{
    class BertClassification : SciSharpExample, IExample
    {
        // Fine-tuning hyperparameters.
        int max_seq_len = 180;
        int batch_size = 4;
        int num_classes = 2;
        int epoch = 3;
        float learning_rate = (float)2e-5;
        string pretrained_weight_path = "./tf_model.h5";
        BertConfig config = new BertConfig();
        NDArray np_x_train;
        NDArray np_y_train;

        public ExampleConfig InitConfig()
            => Config = new ExampleConfig
            {
                Name = "Bert for Classification",
                Enabled = true
            };

        public override void PrepareData()
        {
            // tf.debugging.set_log_device_placement(true);
            Console.WriteLine("Preparing data...");

            // Download and extract the IMDB sentiment dataset.
            string url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
            var dataset = keras.utils.get_file("aclImdb_v1.tar.gz", url,
                untar: true,
                cache_dir: Path.GetTempPath(),
                cache_subdir: "aclImdb_v1");
            var data_dir = Path.Combine(dataset, "aclImdb");
            var train_dir = Path.Combine(data_dir, "train");

            // Tokenize negative (label 0) and positive (label 1) reviews to
            // max_seq_len tokens each, then stack them into one training set.
            (int[,] x_train_neg, int[] y_train_neg) = IMDBDataPreProcessor.
                ProcessData(Path.Combine(train_dir, "neg"), max_seq_len, 0);
            (int[,] x_train_pos, int[] y_train_pos) = IMDBDataPreProcessor.
                ProcessData(Path.Combine(train_dir, "pos"), max_seq_len, 1);
            np_x_train = np.array(x_train_neg, dtype: tf.int32);
            np_y_train = np.array(y_train_neg, dtype: tf.int32);
            np_x_train = np.concatenate((np_x_train, np.array(x_train_pos, dtype: tf.int32)), 0);
            np_y_train = np.concatenate((np_y_train, np.array(y_train_pos, dtype: tf.int32)), 0);
        }

        public bool Run()
        {
            // BERT encoder followed by a dense classification head.
            var model = keras.Sequential();
            model.add(keras.layers.Input(max_seq_len, batch_size, dtype: tf.int32));
            model.add(new BertMainLayer(config));
            if (File.Exists(pretrained_weight_path)) model.load_weights(pretrained_weight_path);
            model.add(keras.layers.Dense(num_classes));

            // AdamW with weight decay, excluding the LayerNorm gamma/beta parameters.
            model.compile(optimizer: keras.optimizers.AdamW(learning_rate, weight_decay: 0.01f, no_decay_params: new List<string> { "gamma", "beta" }),
                loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true), metrics: new[] { "acc" });
            model.summary();

            PrepareData();
            model.fit(np_x_train, np_y_train,
                batch_size: batch_size,
                epochs: epoch,
                shuffle: true,
                validation_split: 0.2f);
            return true;
        }
    }
}
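
For context, a minimal sketch of how this example class would be driven, assuming only the InitConfig/Run contract visible above; the Program wrapper is hypothetical (the repository dispatches IExample implementations through its own runner):

    // Hypothetical driver; assumes only the members shown in BertClassification.
    class Program
    {
        static void Main()
        {
            var example = new BertClassification();
            var config = example.InitConfig();   // returns the ExampleConfig it assigns
            if (config.Enabled)
                example.Run();                   // builds, compiles, and fits the model
        }
    }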
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Keras.ArgsDefinition;

namespace BERT
{
    // Hyperparameter container for the BERT model; defaults match BERT-base.
    class BertConfig : LayerArgs
    {
        public int vocab_size;
        public int hidden_size;
        public int num_hidden_layers;
        public int num_attention_heads;
        public int intermediate_size;
        public string hidden_act;
        public float hidden_dropout_prob;
        public float attention_probs_dropout_prob;
        public int max_position_embeddings;
        public int type_vocab_size;
        public float initializer_range;
        public float layer_norm_eps;
        public int pad_token_id;
        public string position_embedding_type;

        public BertConfig(int vocab_size = 30522,
            int hidden_size = 768,
            int num_hidden_layers = 12,
            int num_attention_heads = 12,
            int intermediate_size = 3072,
            string hidden_act = "gelu",
            double hidden_dropout_prob = 0.1,
            double attention_probs_dropout_prob = 0.1,
            int max_position_embeddings = 512,
            int type_vocab_size = 2,
            double initializer_range = 0.02,
            double layer_norm_eps = 1e-12,
            int pad_token_id = 0,
            string position_embedding_type = "absolute")
        {
            this.vocab_size = vocab_size;
            this.hidden_size = hidden_size;
            this.num_hidden_layers = num_hidden_layers;
            this.num_attention_heads = num_attention_heads;
            this.intermediate_size = intermediate_size;
            this.hidden_act = hidden_act;
            this.hidden_dropout_prob = (float)hidden_dropout_prob;
            this.attention_probs_dropout_prob = (float)attention_probs_dropout_prob;
            this.max_position_embeddings = max_position_embeddings;
            this.type_vocab_size = type_vocab_size;
            this.initializer_range = (float)initializer_range;
            this.layer_norm_eps = (float)layer_norm_eps;
            this.pad_token_id = pad_token_id;
            this.position_embedding_type = position_embedding_type;
        }
    }
}
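
The constructor defaults correspond to the standard BERT-base configuration (30522-token vocabulary, 768 hidden units, 12 layers, 12 attention heads, 3072 intermediate size). A minimal usage sketch, assuming only the class above; the reduced values below are illustrative, not part of the commit:

    // BERT-base sized configuration (all defaults).
    var baseConfig = new BertConfig();

    // Hypothetical smaller variant for quick experiments;
    // hidden_size must stay divisible by num_attention_heads.
    var smallConfig = new BertConfig(
        hidden_size: 256,
        num_hidden_layers: 4,
        num_attention_heads: 4,
        intermediate_size: 1024);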
