using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Operations.Initializers;
using static Tensorflow.KerasApi;
using BERT;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.Keras.Engine.InputSpec;

namespace TensorFlowNET.Examples
{
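    // Fine-tunes a pretrained BERT encoder with a dense classification head for
    // binary sentiment classification on the IMDB movie review dataset.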
    class BertClassification : SciSharpExample, IExample
    {
        int max_seq_len = 180;      // maximum number of tokens per review
        int batch_size = 4;
        int num_classes = 2;        // binary sentiment: negative (0) / positive (1)
        int epoch = 3;
        float learning_rate = 2e-5f;
        string pretrained_weight_path = "./tf_model.h5";
        BertConfig config = new BertConfig();
        NDArray np_x_train;
        NDArray np_y_train;
        public ExampleConfig InitConfig()
            => Config = new ExampleConfig
            {
                Name = "Bert for Classification",
                Enabled = true
            };

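        // Downloads and extracts the IMDB sentiment dataset, then converts the
        // negative and positive training reviews into fixed-length token-id
        // sequences via IMDBDataPreProcessor (label 0 = negative, 1 = positive).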
        public override void PrepareData()
        {
            // tf.debugging.set_log_device_placement(true);
            Console.WriteLine("Preparing data...");
            string url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
            var dataset = keras.utils.get_file("aclImdb_v1.tar.gz", url,
                untar: true,
                cache_dir: Path.GetTempPath(),
                cache_subdir: "aclImdb_v1");
            var data_dir = Path.Combine(dataset, "aclImdb");
            var train_dir = Path.Combine(data_dir, "train");
            (int[,] x_train_neg, int[] y_train_neg) = IMDBDataPreProcessor.
                ProcessData(Path.Combine(train_dir, "neg"), max_seq_len, 0);
            (int[,] x_train_pos, int[] y_train_pos) = IMDBDataPreProcessor.
                ProcessData(Path.Combine(train_dir, "pos"), max_seq_len, 1);
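            // Stack the negative and positive examples into single training arrays.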
            np_x_train = np.array(x_train_neg, dtype: tf.int32);
            np_y_train = np.array(y_train_neg, dtype: tf.int32);
            np_x_train = np.concatenate((np_x_train, np.array(x_train_pos, dtype: tf.int32)), 0);
            np_y_train = np.concatenate((np_y_train, np.array(y_train_pos, dtype: tf.int32)), 0);
        }

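        // Builds the model (token-id input -> BERT encoder -> dense logits),
        // optionally loads pretrained BERT weights, and fine-tunes on IMDB with
        // 20% of the training data held out for validation.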
        public bool Run()
        {
            var model = keras.Sequential();
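            // Explicit int32 input of token ids with shape (batch_size, max_seq_len)
            // feeding the BERT encoder; pretrained weights are loaded from the HDF5
            // checkpoint, if present, before the classification head is added.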
            model.add(keras.layers.Input(max_seq_len, batch_size, dtype: tf.int32));
            model.add(new BertMainLayer(config));
            if (File.Exists(pretrained_weight_path))
                model.load_weights(pretrained_weight_path);
            model.add(keras.layers.Dense(num_classes));
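            // AdamW with weight decay, excluding the "gamma"/"beta" (layer-norm) parameters
            // from decay; sparse categorical cross-entropy is computed from the raw logits.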
            model.compile(optimizer: keras.optimizers.AdamW(learning_rate, weight_decay: 0.01f, no_decay_params: new List<string> { "gamma", "beta" }),
                loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
                metrics: new[] { "acc" });
            model.summary();
            PrepareData();
            model.fit(np_x_train, np_y_train,
                batch_size: batch_size,
                epochs: epoch,
                shuffle: true,
                validation_split: 0.2f);
            return true;
        }
    }
}