diff --git a/README.md b/README.md
index 1c5c744..1d39a46 100644
--- a/README.md
+++ b/README.md
@@ -44,35 +44,6 @@ All the above experiments were tested on GTX 1080 GPU with memory 8000MiB.
 This section provides a solution to visualize the BERT attention matrix.
 For more detail, you can check dictionary "BERT-GCN".
 
-2020/5/11: add TextGCN and TextSAGE for text classification.
-
-2020/5/5: add GIN, GraphSAGE for graph classfication.
-
-2020/4/25: add GAN, GIN model, based on message passing methods.
-
-2020/4/23: add GCN model, based on message passing methods.
-
-2020/4/16:currently focusing on models of GNN in nlp, and trying to integrate some GNN models into fennlp.
-
-2020/4/2: add GPT2 model, could used parameters released by OpenAI (base,medium,large).
-More detail reference dictionary "TG/EN/interactive.py"
-
-2020/3/26: add Bilstm+Attention example for classification
-
-2020/3/23: add RAdam optimizer.
-
-2020/3/19: add test example "albert_ner_train.py" "albert_ner_test.py"
-
-2020/3/16: add model for training sub word embedding based on bpe methods.
-The trained embedding is used in TextCNN model for improve it's improvement.
-See "tran_bpe_embeding.py" for more details.
-
-2020/3/8: add test example "run_tucker.py" for train TuckER on WN18.
-
-2020/3/3: add test example "tran_text_cnn.py" for train TextCNN model.
-
-2020/3/2: add test example "train_bert_classification.py" for text classification based on bert.
-
 # Requirement
 * tensorflow-gpu>=2.0
 * typeguard
@@ -321,6 +292,7 @@ Same data split and parameters setting as proposed in this [paper](https://arxiv
 | ------- | ------- |------- |------- |
 |GCN     |81.80   |79.50  | 71.20 |
 |GAN     |83.00   | 79.00 | 72.30 |
+|GAAE    |82.40   |79.60  | 71.70 |
 
 * Graph Classfication
 
diff --git a/nlpgnn/callbacks.py b/nlpgnn/callbacks.py
index 1196b7c..7d06941 100644
--- a/nlpgnn/callbacks.py
+++ b/nlpgnn/callbacks.py
@@ -54,8 +54,8 @@ def __call__(self, current=None, model=None, moniter_loss=None, moniter_acc=None
             self.restore_best_weights = False
 
         if "both" in self.monitor:
-            assert moniter_acc!=None
-            assert moniter_loss!=None
+            assert moniter_acc is not None
+            assert moniter_loss is not None
             target = moniter_acc - moniter_loss
         else:
             target = current - self.min_delta
@@ -73,3 +73,69 @@ def __call__(self, current=None, model=None, moniter_loss=None, moniter_acc=None
                     print('Restoring model weights from the end of the best epoch.')
                     model.set_weights(self.best_weights)
                 return True
+
+
+class EarlyStoppingScale:
+    def __init__(self, monitor="loss",
+                 min_delta=0,
+                 patience=0,
+                 mode='auto',
+                 baseline=None,
+                 restore_scale=True,
+                 verbose=1):
+        self.wait = 0
+        self.min_delta = min_delta
+        self.mode = mode
+        self.patience = patience
+        self.verbose = verbose
+        self.monitor = monitor
+        self.baseline = baseline
+        self.restore_scale = restore_scale
+        self.best_scale = None
+        if mode not in ['auto', 'min', 'max']:
+            print('EarlyStoppingScale mode %s is unknown, '
+                  'fallback to auto mode.' % mode)
+            mode = 'auto'
+        if mode == 'min':
+            self.monitor_op = np.less
+        elif mode == 'max':
+            self.monitor_op = np.greater
+        else:
+            if 'acc' in self.monitor:
+                self.monitor_op = np.greater
+            elif 'both' in self.monitor:
+                self.monitor_op = np.greater
+            else:
+                self.monitor_op = np.less
+
+        if self.monitor_op == np.greater:
+            self.min_delta *= 1
+        else:
+            self.min_delta *= -1
+
+        if self.baseline is not None:
+            self.best = self.baseline
+        else:
+            self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+
+    def __call__(self, current=None, scale=None, moniter_loss=None, moniter_acc=None):
+
+        if "both" in self.monitor:
+            assert moniter_acc is not None
+            assert moniter_loss is not None
+            target = moniter_acc - moniter_loss
+        else:
+            target = current - self.min_delta
+        if self.monitor_op(target, self.best):
+            self.best = target
+            self.wait = 0
+            if self.restore_scale:
+                self.best_scale = scale
+        else:
+            self.wait += 1
+            if self.wait >= self.patience:
+                print('Early stopping ...')
+                if self.verbose > 0:
+                    print('Restoring scale from the end of the best epoch.')
+                return True, self.best_scale
+        return False, self.best_scale
diff --git a/nlpgnn/gnn/GAAEConv.py b/nlpgnn/gnn/GAAEConv.py
new file mode 100644
index 0000000..7f07e87
--- /dev/null
+++ b/nlpgnn/gnn/GAAEConv.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+"""
+@Author: Kaiyin Zhou
+"""
+
+import tensorflow as tf
+
+from nlpgnn.gnn.messagepassing import MessagePassing
+from nlpgnn.gnn.utils import GNNInput, masksoftmax
+
+
+class GraphAttentionAutoEncoder(MessagePassing):
+    def __init__(self,
+                 out_features,
+                 heads=1,
+                 dropout_rate=0.,
+                 use_bias=False,
+                 kernel_initializer='glorot_uniform',
+                 bias_initializer='zeros',
+                 regularizer=5e-4,
+                 concat=True,
+                 **kwargs):
+        super(GraphAttentionAutoEncoder, self).__init__(aggr="sum", **kwargs)
+        self.use_bias = use_bias
+        self.out_features = out_features
+        self.heads = heads
+        self.dropout_rate = dropout_rate
+        self.kernel_initializer = kernel_initializer
+        self.bias_initializer = bias_initializer
+        self.regularizer = regularizer
+        self.concat = concat
+
+    def build(self, input_shapes):
+        node_embedding_shapes = input_shapes.node_embeddings
+        # adjacency_list_shapes = input_shapes.adjacency_lists
+        in_features = node_embedding_shapes[-1]
+
+        self.att = self.add_weight(
+            shape=(1, self.heads, 2 * self.out_features),
+            initializer=self.kernel_initializer,
+            name='att',
+        )
+
+        if self.use_bias and self.concat:
+            self.bias = self.add_weight(
+                shape=(self.heads * self.out_features,),
+                initializer=self.bias_initializer,
+                name='b',
+            )
+        elif self.use_bias and not self.concat:
+            self.bias = self.add_weight(
+                shape=(self.out_features,),
+                initializer=self.bias_initializer,
+                name='b',
+            )
+
+        self.drop1 = tf.keras.layers.Dropout(self.dropout_rate)
+        self.drop2 = tf.keras.layers.Dropout(self.dropout_rate)
+        self.built = True
+
+    def message_function(self, edge_source_states, edge_source,  # x_j source
+                         edge_target_states, edge_target,  # x_i target
+                         num_incoming_to_node_per_message,  # degree target
+                         num_outing_to_node_per_message,  # degree source
+                         edge_type_idx, training):
+        """
+        :param edge_source_states: [M,H]
+        :param edge_target_states: [M,H]
+        :param num_incoming_to_node_per_message: [M]
+        :param edge_type_idx:
+        :param training:
+        :return: messages [M,heads,D]
+        """
+        # Compute the attention coefficients
+        alpha = tf.concat([edge_target_states, edge_source_states], -1) * self.att  # [M,heads,2D]
+        alpha = tf.reduce_sum(alpha, -1)  # [M,Head]
+        alpha = tf.math.sigmoid(alpha)
+        alpha = masksoftmax(alpha, edge_target)
+        # alpha = self.drop1(alpha, training=training)
+        # edge_source_states = self.drop2(edge_source_states, training=training)
+        # messages = tf.math.sigmoid(edge_source_states) * tf.reshape(alpha, [-1, self.heads, 1])
+        messages = edge_source_states * tf.reshape(alpha, [-1, self.heads, 1])
+        return messages
+
+    def call(self, inputs, weight, transpose_b, training):
+        adjacency_lists = inputs.adjacency_lists
+        node_embeddings = inputs.node_embeddings
+        node_embeddings = tf.linalg.matmul(node_embeddings, weight, transpose_b=transpose_b)
+        node_embeddings = tf.reshape(node_embeddings, [node_embeddings.shape[0], self.heads, -1])
+        aggr_out = self.propagate(GNNInput(node_embeddings, adjacency_lists), training)
+        if self.concat is True:
+            aggr_out = tf.reshape(aggr_out, [-1, self.heads * self.out_features])
+        else:
+            aggr_out = tf.reduce_mean(aggr_out, 1)
+        if self.use_bias:
+            aggr_out += self.bias
+        return aggr_out
diff --git a/nlpgnn/gnn/GATConv.py b/nlpgnn/gnn/GATConv.py
index a70ee01..45b9871 100644
--- a/nlpgnn/gnn/GATConv.py
+++ b/nlpgnn/gnn/GATConv.py
@@ -83,7 +83,7 @@ def message_function(self, edge_source_states, edge_source,  # x_j source
         edge_target_states = tf.reshape(edge_target_states, [-1, self.heads, self.out_features])  # [M,Head,dim]
         # self.att=[1,heads,2*D]
         # [M,keads,2D] * [1,heads,2D]
-        alpha = tf.concat([edge_target_states, edge_source_states], -1) * self.att
+        alpha = tf.concat([edge_target_states, edge_source_states], -1) * self.att  # [M,heads,2D]
         alpha = tf.reduce_mean(alpha, -1)
         alpha = tf.nn.leaky_relu(alpha, self.negative_slope)
         alpha = masksoftmax(alpha, edge_target)  # here not provide nodes num, because we have add self loop at the beginning.
@@ -96,6 +96,7 @@ def call(self, inputs, training):
         adjacency_lists = inputs.adjacency_lists
         node_embeddings = inputs.node_embeddings
         node_embeddings = tf.linalg.matmul(node_embeddings, self.weight)
+
         aggr_out = self.propagate(GNNInput(node_embeddings, adjacency_lists), training)
         if self.concat is True:
             aggr_out = tf.reshape(aggr_out, [-1, self.heads * self.out_features])
diff --git a/nlpgnn/models/GAAE.py b/nlpgnn/models/GAAE.py
new file mode 100644
index 0000000..bcf5019
--- /dev/null
+++ b/nlpgnn/models/GAAE.py
@@ -0,0 +1,56 @@
+# -*- coding:utf-8 -*-
+import tensorflow as tf
+
+from nlpgnn.gnn.GAAEConv import GraphAttentionAutoEncoder
+from nlpgnn.gnn.utils import GNNInput
+
+
+class GAAELayer(tf.keras.layers.Layer):
+    def __init__(self, hidden_dim=16, num_layers=2, heads=1, **kwargs):
+        super(GAAELayer, self).__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.heads = heads
+
+    def build(self, input_shape):
+        input_dim = input_shape[-1]
+
+        self.weight = []
+        self.weight.append(self.add_weight(
+            shape=(input_dim, self.heads * self.hidden_dim),
+            name='wt_0',
+        ))
+        for i in range(self.num_layers - 1):
+            self.weight.append(self.add_weight(
+                shape=(self.hidden_dim, self.heads * self.hidden_dim),
+                name='wt_%d' % (i + 1),
+            ))
+        self.encoder_layers = []
+        self.decoder_layers = []
+        for layer in range(self.num_layers - 1):
+            self.encoder_layers.append(GraphAttentionAutoEncoder(self.hidden_dim, heads=self.heads))
+            self.decoder_layers.append(GraphAttentionAutoEncoder(self.hidden_dim, heads=self.heads))
+        self.encoder_layers.append(GraphAttentionAutoEncoder(self.hidden_dim, heads=self.heads))
+        self.decoder_layers.append(GraphAttentionAutoEncoder(input_dim, heads=self.heads))
+
+    def encoder(self, node_embeddings, adjacency_lists, training):
+        for layer in range(self.num_layers):
+            node_embeddings = self.encoder_layers[layer](GNNInput(node_embeddings, adjacency_lists),
+                                                         self.weight[layer], False, training)
+        return node_embeddings
+
+    def decoder(self, hidden_embeddings, adjacency_lists, training):
+        for layer in range(self.num_layers):
+            hidden_embeddings = self.decoder_layers[layer](GNNInput(hidden_embeddings, adjacency_lists),
+                                                           self.weight[-(layer + 1)],
+                                                           True,
+                                                           training)
+        return hidden_embeddings
+
+    def call(self, node_embeddings, adjacency_lists, training=True):
+        hidden_embeddings = self.encoder(node_embeddings, adjacency_lists, training)
+        reconstruct_embedding = self.decoder(hidden_embeddings, adjacency_lists, training)
+        return hidden_embeddings, reconstruct_embedding
+
+    def predict(self, node_embeddings, adjacency_lists, training=False):
+        return self(node_embeddings, adjacency_lists, training)
diff --git a/nlpgnn/models/__init__.py b/nlpgnn/models/__init__.py
index 4216eb7..cf5cdcb 100644
--- a/nlpgnn/models/__init__.py
+++ b/nlpgnn/models/__init__.py
@@ -21,4 +21,5 @@ from .TextCNN import *
 from .tucker import *
 from .GraphSage import *
-from .TextGCN2019 import *
\ No newline at end of file
+from .TextGCN2019 import *
+from .GAAE import *
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 69b7a2a..eefb8ba 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@ EMAIL = 'zhoukaiyinhzau@gmail.com'
 AUTHOR = 'Kaiyin Zhou'
 REQUIRES_PYTHON = '>=3.6.0'
 
-VERSION = '0.0.7'
+VERSION = '0.0.0'
 
 REQUIRED = [
     'typeguard',
diff --git a/tests/GNN/auto_encoder/GAAE.py b/tests/GNN/auto_encoder/GAAE.py
new file mode 100644
index 0000000..253ab8b
--- /dev/null
+++ b/tests/GNN/auto_encoder/GAAE.py
@@ -0,0 +1,102 @@
+# -*- coding:utf-8 -*-
+import time
+import numpy as np
+import tensorflow as tf
+from nlpgnn.datas import Planetoid
+from nlpgnn.models import GAAELayer
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import classification_report
+from nlpgnn.callbacks import EarlyStoppingScale
+
+hidden_dim = 300
+drop_rate = 0.5
+epoch = 100
+
+# cora, pubmed, citeseer
+data = Planetoid(name="citeseer", loop=True, norm=True)
+
+features, adj, y_train, y_val, y_test, train_mask, val_mask, test_mask = data.load()
+
+train_index = np.argwhere(train_mask == 1).reshape([-1]).tolist()
+valid_index = np.argwhere(val_mask == 1).reshape([-1]).tolist()
+test_index = np.argwhere(test_mask == 1).reshape([-1]).tolist()
+
+
+class GAAE(tf.keras.Model):
+    def __init__(self, hidden_dim, lamb=1, **kwargs):
+        super(GAAE, self).__init__(**kwargs)
+        self.lamb = lamb
+        self.hidden_dim = hidden_dim
+        self.model = GAAELayer(hidden_dim, num_layers=2)
+
+    def call(self, node_embeddings, adjacency_lists, training=True):
+        edge_sources = adjacency_lists[0][:, 0]  # [M]
+        edge_targets = adjacency_lists[0][:, 1]  # [M]
+        hidden_embeddings, reconstruct_embedding = self.model(node_embeddings, adjacency_lists, training)
+        # The reconstruction loss of node features
+        features_loss = tf.sqrt(tf.reduce_sum(tf.pow(node_embeddings - reconstruct_embedding, 2)))
+        # The reconstruction loss of the graph structure
+        s_emb = tf.nn.embedding_lookup(hidden_embeddings, edge_sources)
+        r_emb = tf.nn.embedding_lookup(hidden_embeddings, edge_targets)
+        structure_loss = -tf.math.log(tf.sigmoid(tf.reduce_sum(s_emb * r_emb, axis=-1)))
+        structure_loss = tf.reduce_sum(structure_loss)
+        loss = features_loss + self.lamb * structure_loss
+        return loss, hidden_embeddings
+
+    def predict(self, node_embeddings, adjacency_lists, training=False):
+        return self(node_embeddings, adjacency_lists, training)
+
+
+model = GAAE(hidden_dim)
+
+optimizer = tf.keras.optimizers.Adam(0.1)
+
+# ---------------------------------------------------------
+# For train
+stop_monitor = EarlyStoppingScale(monitor="acc", patience=20, restore_scale=True)
+
+hidden_embeddings = 0
+test_features = 0
+test_y = 0
+loss_v = 0
+for p in range(epoch):
+    t = time.time()
+    with tf.GradientTape() as tape:
+        loss, _ = model(features, adj, training=True)
+        if p == 0: model.summary()
+    grads = tape.gradient(loss, model.variables)
+    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
+
+    loss_v, hidden_embeddings = model.predict(features, adj)
+    hidden_embeddings = hidden_embeddings.numpy()
+
+    train_features = hidden_embeddings[train_index]
+    train_y = np.argmax(y_train[train_index], -1)
+
+    valid_features = hidden_embeddings[valid_index]
+    valid_y = np.argmax(y_val[valid_index], -1)
+
+    clf = LogisticRegression(solver='lbfgs', multi_class='ovr', max_iter=500)
+    clf.fit(train_features, train_y)
+
+    predict_y = clf.predict(valid_features)
+    report_v = classification_report(valid_y, predict_y, digits=4, output_dict=True)
+    acc = report_v["accuracy"]
+    print("EPOCH {:d} loss {:.4f} ACC {:.4f} Time {:.4f}".format(p, float(loss_v), acc, time.time() - t))
+    check, hidden_embeddings = stop_monitor(acc, scale=hidden_embeddings)
+    if check:
+        break
+
+train_features = hidden_embeddings[train_index]
+train_y = np.argmax(y_train[train_index], -1)
+
+test_features = hidden_embeddings[test_index]
+test_y = np.argmax(y_test[test_index], -1)
+
+clf = LogisticRegression(solver='lbfgs', multi_class='ovr', max_iter=500)
+clf.fit(train_features, train_y)
+predict_y = clf.predict(test_features)
+
+report = classification_report(test_y, predict_y, digits=4, output_dict=True)
+acc = report["accuracy"]
+print("Test: Loss {:.5f} Acc {:.5f}".format(float(loss_v), acc))
diff --git a/tests/GNN/nodes_graph_classfication/train_gcn.py b/tests/GNN/nodes_graph_classfication/train_gcn.py
index 264f7fe..dcdd2ed 100644
--- a/tests/GNN/nodes_graph_classfication/train_gcn.py
+++ b/tests/GNN/nodes_graph_classfication/train_gcn.py
@@ -1,5 +1,6 @@
 #! encoding:utf-8
 import time
+import numpy as np
 import tensorflow as tf
 from nlpgnn.datas import Planetoid
 from nlpgnn.metrics import Losess, Metric
@@ -20,6 +21,8 @@
 features, adj, y_train, y_val, y_test, train_mask, val_mask, test_mask = data.load()
 
+
+
 model = GCNLayer(hidden_dim, num_class, drop_rate)
 
 optimizer = tf.keras.optimizers.Adam(0.01)
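
Usage note (not part of the patch): a minimal sketch of how the new EarlyStoppingScale callback added in nlpgnn/callbacks.py is meant to be driven from a training loop, independent of the GAAE test above. The toy_accuracy and toy_embeddings helpers are hypothetical placeholders; the callback only needs the monitored value plus an arbitrary object passed via scale=, and returns a (stop, best_scale) tuple holding the object saved at the best epoch.

import numpy as np
from nlpgnn.callbacks import EarlyStoppingScale

def toy_accuracy(epoch):
    # Hypothetical validation accuracy: improves, then plateaus at 0.8.
    return min(0.5 + 0.01 * epoch, 0.8)

def toy_embeddings(epoch):
    # Hypothetical per-epoch node embeddings.
    return np.random.rand(8, 4)

monitor = EarlyStoppingScale(monitor="acc", patience=5, restore_scale=True)
best_embeddings = None
for epoch in range(100):
    acc = toy_accuracy(epoch)
    # Pass the monitored value and the object to snapshot at the best epoch.
    stop, best_embeddings = monitor(acc, scale=toy_embeddings(epoch))
    if stop:
        break
# best_embeddings now holds the scale object from the best-accuracy epoch,
# which is how tests/GNN/auto_encoder/GAAE.py recovers embeddings for the
# downstream LogisticRegression probe.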