23 | 23 | FLAGS=tf.app.flags.FLAGS
24 | 24 |
25 | 25 | tf.app.flags.DEFINE_string("data_path","./data/","path of training data.")
26 |    | -tf.app.flags.DEFINE_string("training_data_file","./data/bert_train.txt","path of training data.") #./data/cail2018_bi.json
27 |    | -tf.app.flags.DEFINE_string("valid_data_file","./data/bert_train.txt","path of validation data.")
28 |    | -tf.app.flags.DEFINE_string("test_data_file","./data/bert_test.txt","path of test data.")
   | 26 | +tf.app.flags.DEFINE_string("training_data_file","./data/bert_train2.txt","path of training data.") #./data/cail2018_bi.json
   | 27 | +tf.app.flags.DEFINE_string("valid_data_file","./data/bert_valid2.txt","path of validation data.")
   | 28 | +tf.app.flags.DEFINE_string("test_data_file","./data/bert_test2.txt","path of test data.")
29 | 29 | tf.app.flags.DEFINE_string("ckpt_dir","./checkpoint_lm/","checkpoint directory to restore the pre-trained model from") #save to here, so make it easy to upload for test
30 | 30 | tf.app.flags.DEFINE_string("ckpt_dir_save","./checkpoint_lm_save/","checkpoint directory to save the fine-tuned model to") #save to here, so make it easy to upload for test
31 | 31 |
35 | 35 | tf.app.flags.DEFINE_float("learning_rate",0.00001,"learning rate") #0.001
36 | 36 | tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size for training/evaluating.") # 32-->128
37 | 37 | tf.app.flags.DEFINE_integer("decay_steps", 10000, "how many steps before decaying the learning rate.")
38 |    | -tf.app.flags.DEFINE_float("decay_rate", 0.9, "Rate of decay for learning rate.") #0.65
   | 38 | +tf.app.flags.DEFINE_float("decay_rate", 0.8, "Rate of decay for learning rate.") #0.65
39 | 39 | tf.app.flags.DEFINE_float("dropout_keep_prob", 0.9, "percentage to keep when using dropout.") #0.65
40 | 40 | tf.app.flags.DEFINE_integer("sequence_length",200,"max sentence length") #400
41 | 41 | tf.app.flags.DEFINE_integer("sequence_length_lm",10,"max sentence length for masked language model")
42 | 42 |
43 | 43 | tf.app.flags.DEFINE_boolean("is_training",True,"is training. true: training, false: testing/inference")
44 | 44 | tf.app.flags.DEFINE_boolean("is_fine_tuning",True,"is_fine_tuning. true: this is the fine-tuning stage")
45 | 45 |
46 |    | -tf.app.flags.DEFINE_integer("num_epochs",30,"number of epochs to run.")
47 |    | -tf.app.flags.DEFINE_integer("process_num",3,"number of CPU processes used")
   | 46 | +tf.app.flags.DEFINE_integer("num_epochs",35,"number of epochs to run.")
   | 47 | +tf.app.flags.DEFINE_integer("process_num",35,"number of CPU processes used")
48 | 48 |
49 | 49 | tf.app.flags.DEFINE_integer("validate_every", 1, "Validate every validate_every epochs.")
50 | 50 | tf.app.flags.DEFINE_boolean("use_pretrained_embedding",False,"whether to use pre-trained embeddings or not.")
51 | 51 | tf.app.flags.DEFINE_string("word2vec_model_path","./data/Tencent_AILab_ChineseEmbedding_100w.txt","word2vec's vocabulary and vectors") # data/sgns.target.word-word.dynwin5.thr10.neg5.dim300.iter5--->data/news_12g_baidubaike_20g_novel_90g_embedding_64.bin--->sgns.merge.char
52 |    | -tf.app.flags.DEFINE_boolean("test_mode",True,"whether it is test mode. In test mode, only a small percentage of the data is used, for testing purposes.")
   | 52 | +tf.app.flags.DEFINE_boolean("test_mode",False,"whether it is test mode. In test mode, only a small percentage of the data is used, for testing purposes.")
53 | 53 |
54 | 54 | tf.app.flags.DEFINE_integer("d_model", 64, "dimension of model") # 512-->128
55 | 55 | tf.app.flags.DEFINE_integer("num_layer", 6, "number of layers")
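
For anyone reading the flag block above: this is the TensorFlow 1.x `tf.app.flags` API (removed in TF 2.x), which registers command-line flags on a module-global `FLAGS` object and parses `sys.argv` when `tf.app.run()` is called. A minimal sketch of the pattern, reusing a few flag names from this diff (the `main` body is illustrative only):

```python
import tensorflow as tf  # TensorFlow 1.x; tf.app.flags does not exist in TF 2.x

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float("learning_rate", 0.00001, "learning rate")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size for training/evaluating.")
tf.app.flags.DEFINE_boolean("test_mode", False, "use only a small slice of the data")

def main(_):
    # tf.app.run() parses sys.argv before invoking main(), so command-line
    # overrides such as `python train.py --batch_size=128` are visible here.
    print("lr=%g batch_size=%d test_mode=%s"
          % (FLAGS.learning_rate, FLAGS.batch_size, FLAGS.test_mode))

if __name__ == "__main__":
    tf.app.run()
```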
@@ -81,7 +81,7 @@ def main(_):
81 | 81 |     if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
82 | 82 |         print("Restoring Variables from Checkpoint.")
83 | 83 |         sess.run(tf.global_variables_initializer())
84 |    | -        for i in range(2): # decay learning rate if necessary.
   | 84 | +        for i in range(6): # decay learning rate if necessary.
85 | 85 |             print(i,"Going to decay learning rate by a factor of "+str(FLAGS.decay_rate))
86 | 86 |             sess.run(model.learning_rate_decay_half_op)
87 | 87 |         # restore only those variables whose names and shapes exist in your model from the checkpoint. For details see: https://gist.github.com/iganichev/d2d8a0b1abc6b15d4a07de83171163d4
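
`model.learning_rate_decay_half_op` is defined in the model class, which this diff does not show; judging from the loop above, it is an op that shrinks the stored learning rate each time it is run. A hedged reconstruction of that pattern (the variable name, initial value, and decay factor are assumptions taken from the flags above, not from the model source):

```python
import tensorflow as tf  # TensorFlow 1.x

# Assumed shape of the op: a non-trainable variable holding the current
# learning rate, plus an assign op that multiplies it by decay_rate per run.
learning_rate = tf.Variable(0.00001, trainable=False, name="learning_rate")
decay_rate = 0.8  # FLAGS.decay_rate as of this commit
learning_rate_decay_half_op = tf.assign(learning_rate, learning_rate * decay_rate)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(6):  # mirrors the new range(6) loop above
        print(i, "lr ->", sess.run(learning_rate_decay_half_op))
```

Six decays at 0.8 take the 1e-5 starting rate down to roughly 2.6e-6, which is presumably why the log format in the next hunk widens from `%.5f` to `%.7f`.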
@@ -110,7 +110,7 @@ def main(_):
110 | 110 |     current_loss,lr,l2_loss,_=sess.run([model.loss_val,model.learning_rate,model.l2_loss,model.train_op],feed_dict)
111 | 111 |     loss_total,counter=loss_total+current_loss,counter+1
112 | 112 |     if counter%30==0:
113 |     | -        print("Learning rate:%.5f\tLoss:%.3f\tCurrent_loss:%.3f\tL2_loss:%.3f"%(lr,float(loss_total)/float(counter),current_loss,l2_loss))
    | 113 | +        print("Learning rate:%.7f\tLoss:%.3f\tCurrent_loss:%.3f\tL2_loss:%.3f"%(lr,float(loss_total)/float(counter),current_loss,l2_loss))
114 | 114 |     if start!=0 and start%(4000*FLAGS.batch_size)==0:
115 | 115 |         loss_valid, f1_macro_valid, f1_micro_valid = do_eval(sess, model, valid, num_classes, label2index)
116 | 116 |         f1_score_valid = ((f1_macro_valid+f1_micro_valid)/2.0) #*100.0
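
`do_eval` returns a macro- and a micro-averaged F1 score, and the validation score is simply their mean. A standalone sketch of that scoring step using scikit-learn as a stand-in (`do_eval`'s own implementation is not part of this diff):

```python
from sklearn.metrics import f1_score

def validation_score(y_true, y_pred):
    # Macro-F1 averages per-class F1 scores (each class counts equally);
    # micro-F1 pools all decisions first (each sample counts equally).
    f1_macro = f1_score(y_true, y_pred, average="macro")
    f1_micro = f1_score(y_true, y_pred, average="micro")
    return (f1_macro + f1_micro) / 2.0  # same combination as f1_score_valid

print(validation_score([0, 1, 2, 2, 1], [0, 2, 2, 2, 1]))
```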