
Commit 807ba46

2020/5/11

kyzhouhzau committed May 28, 2020
1 parent 95aa331 commit 807ba46
Showing 9 changed files with 333 additions and 34 deletions.
30 changes: 1 addition & 29 deletions README.md
@@ -44,35 +44,6 @@ All the above experiments were tested on a GTX 1080 GPU with 8000 MiB of memory.
This section provides a solution to visualize the BERT attention matrix.
For more detail, you can check the directory "BERT-GCN".

2020/5/11: add TextGCN and TextSAGE for text classification.

2020/5/5: add GIN and GraphSAGE for graph classification.

2020/4/25: add GAN and GIN models, based on message passing methods.

2020/4/23: add GCN model, based on message passing methods.

2020/4/16: currently focusing on GNN models for NLP, and trying to integrate some GNN models into fennlp.

2020/4/2: add GPT2 model, which can use the parameters released by OpenAI (base, medium, large).
For more detail, see "TG/EN/interactive.py".

2020/3/26: add BiLSTM+Attention example for classification.

2020/3/23: add RAdam optimizer.

2020/3/19: add test examples "albert_ner_train.py" and "albert_ner_test.py".

2020/3/16: add model for training subword embeddings based on BPE methods.
The trained embeddings are used in the TextCNN model to improve its performance.
See "tran_bpe_embeding.py" for more details.

2020/3/8: add test example "run_tucker.py" for training TuckER on WN18.

2020/3/3: add test example "tran_text_cnn.py" for training the TextCNN model.

2020/3/2: add test example "train_bert_classification.py" for text classification based on BERT.

# Requirement
* tensorflow-gpu>=2.0
* typeguard
@@ -321,6 +292,7 @@ Same data split and parameter settings as proposed in this [paper](https://arxiv
| ------- | ------- |------- |------- |
|GCN |81.80 |79.50 | 71.20 |
|GAN |83.00 | 79.00 | 72.30 |
|GAAE |82.40 |79.60 | 71.70 |

* Graph Classification

69 changes: 67 additions & 2 deletions nlpgnn/callbacks.py
@@ -54,8 +54,8 @@ def __call__(self, current=None, model=None, moniter_loss=None, moniter_acc=None
self.restore_best_weights = False

if "both" in self.monitor:
assert moniter_acc!=None
assert moniter_loss!=None
assert moniter_acc != None
assert moniter_loss != None
target = moniter_acc - moniter_loss
else:
target = current - self.min_delta
@@ -73,3 +73,68 @@ def __call__(self, current=None, model=None, moniter_loss=None, moniter_acc=None
print('Restoring model weights from the end of the best epoch.')
model.set_weights(self.best_weights)
return True


class EarlyStoppingScale:
def __init__(self, monitor="loss",
min_delta=0,
patience=0,
mode='auto',
baseline=None,
restore_scale=True,
verbose=1):
self.wait = 0
self.min_delta = min_delta
self.mode = mode
self.patience = patience
self.verbose = verbose
self.monitor = monitor
self.baseline = baseline
        self.restore_scale = restore_scale
        self.best_scale = None  # ensure a defined return value even if no improvement is ever seen
if mode not in ['auto', 'min', 'max']:
            print('EarlyStopping mode %s is unknown, '
                  'fallback to auto mode.' % mode)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
elif 'both' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less

if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1

if self.baseline is not None:
self.best = self.baseline
else:
self.best = np.Inf if self.monitor_op == np.less else -np.Inf

def __call__(self, current=None, scale=None, moniter_loss=None, moniter_acc=None):

if "both" in self.monitor:
assert moniter_acc != None
assert moniter_loss != None
target = moniter_acc - moniter_loss
else:
target = current - self.min_delta
if self.monitor_op(target, self.best):
self.best = target
self.wait = 0
if self.restore_scale:
self.best_scale = scale
else:
self.wait += 1
if self.wait >= self.patience:
print('Early stopping ...')
if self.verbose > 0:
print('Restoring scale from the end of the best epoch.')
return True, self.best_scale
return False, self.best_scale
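
A minimal usage sketch for the new EarlyStoppingScale callback, assuming a training loop that yields an epoch loss and some scale tensor to track (`train_one_epoch` and `scale` are placeholders, not repository code):

```python
from nlpgnn.callbacks import EarlyStoppingScale

# monitor="loss" means lower is better; patience counts non-improving epochs.
stopper = EarlyStoppingScale(monitor="loss", patience=10, restore_scale=True)

for epoch in range(200):
    loss, scale = train_one_epoch()  # placeholder: returns epoch loss and current scale
    stop, best_scale = stopper(current=loss, scale=scale)
    if stop:  # patience exhausted; best_scale holds the scale from the best epoch
        break
```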
99 changes: 99 additions & 0 deletions nlpgnn/gnn/GAAEConv.py
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
@Author:Kaiyin Zhou
"""

import tensorflow as tf

from nlpgnn.gnn.messagepassing import MessagePassing
from nlpgnn.gnn.utils import GNNInput, masksoftmax


class GraphAttentionAutoEncoder(MessagePassing):
def __init__(self,
out_features,
heads=1,
dropout_rate=0.,
use_bias=False,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
regularizer=5e-4,
concat=True,
**kwargs):
super(GraphAttentionAutoEncoder, self).__init__(aggr="sum", **kwargs)
self.use_bias = use_bias
self.out_features = out_features
self.heads = heads
self.dropout_rate = dropout_rate
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
self.regularizer = regularizer
self.concat = concat

def build(self, input_shapes):
node_embedding_shapes = input_shapes.node_embeddings
# adjacency_list_shapes = input_shapes.adjacency_lists
in_features = node_embedding_shapes[-1]

self.att = self.add_weight(
shape=(1, self.heads, 2 * self.out_features),
initializer=self.kernel_initializer,
name='att',
)

if self.use_bias and self.concat:
self.bias = self.add_weight(
shape=(self.heads * self.out_features,),
initializer=self.bias_initializer,
name='b',
)
elif self.use_bias and not self.concat:
self.bias = self.add_weight(
shape=(self.out_features,),
initializer=self.bias_initializer,
name='b',
)

self.drop1 = tf.keras.layers.Dropout(self.dropout_rate)
self.drop2 = tf.keras.layers.Dropout(self.dropout_rate)
self.built = True

def message_function(self, edge_source_states, edge_source, # x_j source
edge_target_states, edge_target, # x_i target
num_incoming_to_node_per_message, # degree target
num_outing_to_node_per_message, # degree source
edge_type_idx, training):
"""
:param edge_source_states: [M,H]
:param edge_target_states: [M,H]
:param num_incoming_to_node_per_message:[M]
:param edge_type_idx:
:param training:
:return:
"""
        # compute the attention coefficients
alpha = tf.concat([edge_target_states, edge_source_states], -1) * self.att #[M,heads,2D]
alpha = tf.reduce_sum(alpha, -1) # [M,Head]
alpha = tf.math.sigmoid(alpha)
alpha = masksoftmax(alpha, edge_target)
# alpha = self.drop1(alpha, training=training)
# edge_source_states = self.drop2(edge_source_states, training=training)
# messages = tf.math.sigmoid(edge_source_states) * tf.reshape(alpha, [-1, self.heads, 1])
messages = edge_source_states * tf.reshape(alpha, [-1, self.heads, 1])
return messages

def call(self, inputs, weight, transpose_b, training):
adjacency_lists = inputs.adjacency_lists
node_embeddings = inputs.node_embeddings
node_embeddings = tf.linalg.matmul(node_embeddings, weight, transpose_b=transpose_b)

node_embeddings = tf.reshape(node_embeddings, [node_embeddings.shape[0], self.heads, -1])
aggr_out = self.propagate(GNNInput(node_embeddings, adjacency_lists), training)
if self.concat is True:
aggr_out = tf.reshape(aggr_out, [-1, self.heads * self.out_features])
else:
aggr_out = tf.reduce_mean(aggr_out, 1)
if self.use_bias:
aggr_out += self.bias
return aggr_out
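
GraphAttentionAutoEncoder takes its projection weight at call time, so an encoder/decoder pair can share one matrix: applied as-is when encoding and transposed (`transpose_b=True`) when decoding. A minimal round trip under assumed conventions (GNNInput holds `(node_embeddings, adjacency_lists)`, each adjacency list is an `[E, 2]` int32 tensor of (source, target) pairs with self loops already added; the toy graph is illustrative only):

```python
import tensorflow as tf

from nlpgnn.gnn.GAAEConv import GraphAttentionAutoEncoder
from nlpgnn.gnn.utils import GNNInput

num_nodes, in_dim, hidden_dim = 4, 8, 16
x = tf.random.normal([num_nodes, in_dim])
edges = tf.constant([[0, 1], [1, 2], [2, 3],
                     [0, 0], [1, 1], [2, 2], [3, 3]], dtype=tf.int32)  # self loops included

w = tf.Variable(tf.random.normal([in_dim, hidden_dim]))  # shared projection matrix
encoder = GraphAttentionAutoEncoder(hidden_dim, heads=1)
decoder = GraphAttentionAutoEncoder(in_dim, heads=1)

h = encoder(GNNInput(x, [edges]), w, False, False)     # x @ w   -> [4, 16]
x_rec = decoder(GNNInput(h, [edges]), w, True, False)  # h @ w^T -> [4, 8]
```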
3 changes: 2 additions & 1 deletion nlpgnn/gnn/GATConv.py
@@ -83,7 +83,7 @@ def message_function(self, edge_source_states, edge_source, # x_j source
edge_target_states = tf.reshape(edge_target_states, [-1, self.heads, self.out_features]) # [M,Head,dim]
# self.att=[1,heads,2*D]
        # [M,heads,2D] * [1,heads,2D]
alpha = tf.concat([edge_target_states, edge_source_states], -1) * self.att
alpha = tf.concat([edge_target_states, edge_source_states], -1) * self.att #[M,heads,2D]
alpha = tf.reduce_mean(alpha, -1)
alpha = tf.nn.leaky_relu(alpha, self.negative_slope)
        alpha = masksoftmax(alpha, edge_target)  # no node count needed here, because self loops were added at the beginning
@@ -96,6 +96,7 @@ def call(self, inputs, training):
adjacency_lists = inputs.adjacency_lists
node_embeddings = inputs.node_embeddings
node_embeddings = tf.linalg.matmul(node_embeddings, self.weight)

aggr_out = self.propagate(GNNInput(node_embeddings, adjacency_lists), training)
if self.concat is True:
aggr_out = tf.reshape(aggr_out, [-1, self.heads * self.out_features])
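Both GATConv and the new GraphAttentionAutoEncoder normalize attention scores with masksoftmax, a softmax over all edges that share the same target node. A sketch of the assumed semantics (illustrative only, not the repository's implementation):

```python
import tensorflow as tf

def segment_softmax(scores, segment_ids):
    # scores: [M] or [M, heads] edge scores; segment_ids: [M] target node ids.
    # Self loops guarantee every node id appears, so max(id) + 1 is the node count.
    num_segments = tf.reduce_max(segment_ids) + 1
    seg_max = tf.math.unsorted_segment_max(scores, segment_ids, num_segments)
    scores = scores - tf.gather(seg_max, segment_ids)  # subtract per-node max for stability
    exp = tf.exp(scores)
    denom = tf.math.unsorted_segment_sum(exp, segment_ids, num_segments)
    return exp / tf.gather(denom, segment_ids)
```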
56 changes: 56 additions & 0 deletions nlpgnn/models/GAAE.py
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
import tensorflow as tf

from nlpgnn.gnn.GAAEConv import GraphAttentionAutoEncoder
from nlpgnn.gnn.utils import GNNInput


class GAAELayer(tf.keras.layers.Layer):
def __init__(self, hidden_dim=16, num_layers=2, heads=1, **kwargs):
super(GAAELayer, self).__init__(**kwargs)
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.heads = heads

def build(self, input_shape):
input_dim = input_shape[-1]

self.weight = []
self.weight.append(self.add_weight(
shape=(input_dim, self.heads * self.hidden_dim),
name='wt',
))
for i in range(self.num_layers - 1):
self.weight.append(self.add_weight(
shape=(self.hidden_dim, self.heads * self.hidden_dim),
name='wt',
))
self.encoder_layers = []
self.decoder_layers = []
for layer in range(self.num_layers - 1):
self.encoder_layers.append(GraphAttentionAutoEncoder(self.hidden_dim, heads=self.heads))
self.decoder_layers.append(GraphAttentionAutoEncoder(self.hidden_dim, heads=self.heads))
self.encoder_layers.append(GraphAttentionAutoEncoder(self.hidden_dim, heads=self.heads))
self.decoder_layers.append(GraphAttentionAutoEncoder(input_dim, heads=self.heads))

def encoder(self, node_embeddings, adjacency_lists, training):
for layer in range(self.num_layers):
node_embeddings = self.encoder_layers[layer](GNNInput(node_embeddings, adjacency_lists), self.weight[layer],
False, training)
return node_embeddings

def decoder(self, hidden_embeddings, adjacency_lists, training):
for layer in range(self.num_layers):
hidden_embeddings = self.decoder_layers[layer](GNNInput(hidden_embeddings, adjacency_lists),
self.weight[-(layer+1)],
True,
training)
return hidden_embeddings

def call(self, node_embeddings, adjacency_lists, training=True):
hidden_embeddings = self.encoder(node_embeddings, adjacency_lists, training)
reconstruct_embedding = self.decoder(hidden_embeddings, adjacency_lists, training)
return hidden_embeddings, reconstruct_embedding

def predict(self, node_embeddings, adjacency_lists, training=False):
return self(node_embeddings, adjacency_lists, training)
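
GAAELayer returns both the hidden embeddings and the reconstruction, so it can be trained as an autoencoder on reconstruction error. A hypothetical end-to-end sketch (the random features, toy edge list, and squared-error loss are assumptions for illustration; heads=1 keeps the tied-weight shapes consistent):

```python
import tensorflow as tf

from nlpgnn.models import GAAELayer

num_nodes, feature_dim = 4, 8
x = tf.random.normal([num_nodes, feature_dim])
edges = tf.constant([[0, 1], [1, 2], [2, 3],
                     [0, 0], [1, 1], [2, 2], [3, 3]], dtype=tf.int32)  # self loops included

model = GAAELayer(hidden_dim=16, num_layers=2, heads=1)
optimizer = tf.keras.optimizers.Adam(1e-3)

for step in range(100):
    with tf.GradientTape() as tape:
        hidden, recon = model(x, [edges], training=True)
        loss = tf.reduce_mean(tf.square(recon - x))  # feature reconstruction error
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

node_repr, _ = model.predict(x, [edges])  # hidden embeddings for downstream use
```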
3 changes: 2 additions & 1 deletion nlpgnn/models/__init__.py
@@ -21,4 +21,5 @@
from .TextCNN import *
from .tucker import *
from .GraphSage import *
from .TextGCN2019 import *
from .TextGCN2019 import *
from .GAAE import *
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
EMAIL = '[email protected]'
AUTHOR = 'Kaiyin Zhou'
REQUIRES_PYTHON = '>=3.6.0'
VERSION = '0.0.7'
VERSION = '0.0.0'

REQUIRED = [
'typeguard',
