Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/autodiff/rnn #1723

Open
wants to merge 8 commits into
base: devel
Choose a base branch
from
223 changes: 215 additions & 8 deletions psyneulink/core/components/functions/transferfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,7 @@
from psyneulink.core.components.functions.selectionfunctions import OneHot
from psyneulink.core.components.functions.statefulfunctions.integratorfunctions import SimpleIntegrator
from psyneulink.core.components.shellclasses import Projection
from psyneulink.core.globals.keywords import \
ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, \
GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, \
IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, \
TANH_FUNCTION, MATRIX_KEYWORD_NAMES, MATRIX, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, \
OFF, OFFSET, ON, PER_ITEM, PROB, PRODUCT, OUTPUT_TYPE, PROB_INDICATOR, \
RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM,\
TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIANCE, VARIABLE, X_0, PREFERENCE_SET_NAME
from psyneulink.core.globals.keywords import ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, LSTM_FUNCTION, MATRIX, MATRIX_KEYWORD_NAMES, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, OFF, OFFSET, ON, OUTPUT_TYPE, PER_ITEM, PREFERENCE_SET_NAME, PROB, PROB_INDICATOR, PRODUCT, RANDOM_CONNECTIVITY_MATRIX, RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM, TANH_FUNCTION, TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIABLE, VARIANCE, X_0
from psyneulink.core.globals.parameters import \
Parameter, get_validator_by_function
from psyneulink.core.globals.utilities import parameter_spec, get_global_seed, safe_len
Expand Down Expand Up @@ -2530,7 +2523,221 @@ def derivative(self, output, input=None, context=None):

return derivative

# **********************************************************************************************************************
# LSTM
# **********************************************************************************************************************

class LSTM(TransferFunction):
    """Long Short-Term Memory (LSTM) cell transfer function.

    Implements one step of the standard LSTM recurrence. The function's
    `variable` is a 2d array of the form ``[x_t, h_prev, c_prev]``
    (current input, previous hidden state, previous cell state), and the
    result is ``[h_t, c_t]`` (new hidden state, new cell state), computed
    with the usual input / forget / cell-candidate / output gates (see
    `_function` for the exact equations).

    Weight matrices default to `RANDOM_CONNECTIVITY_MATRIX` and are sized
    from the shapes of ``variable[0]`` (input) and ``variable[1]`` (hidden)
    in `_instantiate_attributes_before_function`.
    """
    componentName = LSTM_FUNCTION

    def __init__(self,
                 default_variable=None,
                 params=None,
                 owner=None,
                 prefs: tc.optional(is_pref_set) = None):

        super().__init__(
            default_variable=default_variable,
            params=params,
            owner=owner,
            prefs=prefs)

    class Parameters(TransferFunction.Parameters):
        # Input gate: i_t = i_gate_func(i_input_matrix @ x_t + i_hidden_matrix @ h_prev)
        i_input_matrix = Parameter(modulable=True)
        i_hidden_matrix = Parameter(modulable=True)
        i_gate_func = Parameter(default_value=Logistic())

        # Forget gate: f_t = f_gate_func(f_input_matrix @ x_t + f_hidden_matrix @ h_prev)
        f_input_matrix = Parameter(modulable=True)
        f_hidden_matrix = Parameter(modulable=True)
        f_gate_func = Parameter(default_value=Logistic())

        # Cell-state candidate: g_t = g_gate_func(g_input_matrix @ x_t + g_hidden_matrix @ h_prev)
        g_input_matrix = Parameter(modulable=True)
        g_hidden_matrix = Parameter(modulable=True)
        g_gate_func = Parameter(default_value=Tanh())

        # Output gate: o_t = o_gate_func(o_input_matrix @ x_t + o_hidden_matrix @ h_prev)
        o_input_matrix = Parameter(modulable=True)
        o_hidden_matrix = Parameter(modulable=True)
        o_gate_func = Parameter(default_value=Logistic())

        # Applied to the new cell state when forming the new hidden state:
        # h_t = o_t * h_gate_func(c_t)
        h_gate_func = Parameter(default_value=Tanh())


    def _instantiate_attributes_before_function(self, function=None, context=None):
        """Size the weight matrices and gate functions from the default variable.

        Input matrices are (hidden_size x input_size); hidden matrices are
        (hidden_size x hidden_size). Any matrix left unspecified defaults to
        RANDOM_CONNECTIVITY_MATRIX. Each gate function's default variable and
        value are resized to a zero vector of length hidden_size.

        NOTE(review): does not call super()._instantiate_attributes_before_function
        — confirm this is intentional.
        """
        input_size = len(self.variable[0])
        hidden_size = len(self.variable[1])

        # Instantiate input matrices
        for param_id in ["i_input_matrix", "f_input_matrix", "g_input_matrix", "o_input_matrix"]:
            param_val = getattr(self, param_id, None)
            if param_val is None:
                param_val = RANDOM_CONNECTIVITY_MATRIX

            setattr(self, param_id, get_matrix(param_val, hidden_size, input_size, context=context))

        # Instantiate hidden matrices
        for param_id in ["i_hidden_matrix", "f_hidden_matrix", "g_hidden_matrix", "o_hidden_matrix"]:
            param_val = getattr(self, param_id, None)
            if param_val is None:
                param_val = RANDOM_CONNECTIVITY_MATRIX

            setattr(self, param_id, get_matrix(param_val, hidden_size, hidden_size, context=context))

        # Instantiate function default variables
        for param_id in ["i_gate_func", "f_gate_func","g_gate_func", "o_gate_func", "h_gate_func"]:
            param_val = getattr(self, param_id)
            param_val.default_variable = np.zeros(hidden_size)
            param_val.defaults.variable = np.zeros(hidden_size)
            param_val.variable = np.zeros(hidden_size)
            param_val.default_value = np.zeros(hidden_size)
            param_val.defaults.value = np.zeros(hidden_size)
            param_val.value = np.zeros(hidden_size)

    def _function(self,
                  variable=None,
                  context=None,
                  params=None,
                  ):
        """Compute one LSTM step.

        Arguments
        ---------
        variable : 2d array-like, [x_t, h_prev, c_prev]
            current input, previous hidden state, previous cell state.

        Returns
        -------
        list : [h_t, c_t]
            new hidden state and new cell state.
        """
        x_t = variable[0]
        h_prev = variable[1]
        c_prev = variable[2]

        # Calculate input gate: i_t = sigma(W_i x_t + U_i h_prev)
        i_input_matrix = self._get_current_function_param("i_input_matrix", context=context)
        i_hidden_matrix = self._get_current_function_param("i_hidden_matrix", context=context)
        i_gate_func = self._get_current_function_param("i_gate_func", context=context)
        i_t = i_gate_func(np.matmul(i_input_matrix, x_t) + np.matmul(i_hidden_matrix, h_prev))

        # Calculate forget gate: f_t = sigma(W_f x_t + U_f h_prev)
        f_input_matrix = self._get_current_function_param("f_input_matrix", context=context)
        f_hidden_matrix = self._get_current_function_param("f_hidden_matrix", context=context)
        f_gate_func = self._get_current_function_param("f_gate_func", context=context)
        f_t = f_gate_func(np.matmul(f_input_matrix, x_t) + np.matmul(f_hidden_matrix, h_prev))

        # Update cell state: c_t = f_t * c_prev + i_t * g_t
        g_input_matrix = self._get_current_function_param("g_input_matrix", context=context)
        g_hidden_matrix = self._get_current_function_param("g_hidden_matrix", context=context)
        g_gate_func = self._get_current_function_param("g_gate_func", context=context)
        g_t = g_gate_func(np.matmul(g_input_matrix, x_t) + np.matmul(g_hidden_matrix, h_prev))
        c_t = np.multiply(f_t, c_prev) + np.multiply(i_t, g_t)

        # Calculate output gate: o_t = sigma(W_o x_t + U_o h_prev)
        o_input_matrix = self._get_current_function_param("o_input_matrix", context=context)
        o_hidden_matrix = self._get_current_function_param("o_hidden_matrix", context=context)
        o_gate_func = self._get_current_function_param("o_gate_func", context=context)
        o_t = o_gate_func(np.matmul(o_input_matrix, x_t) + np.matmul(o_hidden_matrix, h_prev))

        # Update hidden state: h_t = o_t * h_gate_func(c_t)
        h_gate_func = self._get_current_function_param("h_gate_func", context=context)
        h_t = np.multiply(o_t, h_gate_func(c_t))
        value = [h_t, c_t]

        return value

    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        """Emit LLVM IR computing the same step as `_function`.

        arg_in is the [x_t, h_prev, c_prev] struct; arg_out receives
        [h_t, c_t]. Uses the builtin mxm / vec_add / vec_hadamard kernels
        and the compiled gate functions for the element-wise nonlinearities.
        """
        matmul = ctx.import_llvm_function("__pnl_builtin_mxm")
        vecadd = ctx.import_llvm_function("__pnl_builtin_vec_add")
        vechadamard = ctx.import_llvm_function("__pnl_builtin_vec_hadamard")

        x_t = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])
        h_prev = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)])
        c_prev = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(2)])

        # Matrix-vector product m @ v, via mxm with a single column.
        # NOTE(review): the temporary is sized from h_prev — assumes every
        # intermediate has hidden_size length; confirm for the input matrices.
        def _mxv(m, v):
            tmp = builder.alloca(h_prev.type.pointee)
            tmp_ptr = builder.gep(tmp, [ctx.int32_ty(0),
                                        ctx.int32_ty(0)])
            dim_x = len(m.type.pointee)
            dim_y = len(m.type.pointee.elements[0])
            m_ptr = builder.gep(m, [ctx.int32_ty(0),
                                    ctx.int32_ty(0),
                                    ctx.int32_ty(0)])
            v_ptr = builder.gep(v, [ctx.int32_ty(0),
                                    ctx.int32_ty(0)])

            builder.call(matmul, [m_ptr,
                                  v_ptr,
                                  ctx.int32_ty(dim_x),
                                  ctx.int32_ty(dim_y),
                                  ctx.int32_ty(1),
                                  tmp_ptr])

            return tmp

        # Element-wise (Hadamard) product v1 * v2 into a fresh buffer.
        def _vxv(v1, v2):
            tmp = builder.alloca(h_prev.type.pointee)
            tmp_ptr = builder.gep(tmp, [ctx.int32_ty(0),
                                        ctx.int32_ty(0)])
            dim_x = len(v1.type.pointee)
            v1_ptr = builder.gep(v1, [ctx.int32_ty(0),
                                      ctx.int32_ty(0)])
            v2_ptr = builder.gep(v2, [ctx.int32_ty(0),
                                      ctx.int32_ty(0)])

            builder.call(vechadamard, [v1_ptr,
                                       v2_ptr,
                                       ctx.int32_ty(dim_x),
                                       tmp_ptr])
            return tmp

        # Multiply-accumulate: mul_op(m1, v1) + mul_op(m2, v2).
        # The sum is written back into the first product's buffer, which is
        # returned.
        def _mac(m1, v1, m2, v2, mul_op=_mxv):
            val1 = mul_op(m1, v1)
            val2 = mul_op(m2, v2)
            val1_ptr = builder.gep(val1, [ctx.int32_ty(0),
                                          ctx.int32_ty(0)])
            val2_ptr = builder.gep(val2, [ctx.int32_ty(0),
                                          ctx.int32_ty(0)])
            builder.call(vecadd, [val1_ptr,
                                  val2_ptr,
                                  ctx.int32_ty(len(m1.type.pointee)),
                                  val1_ptr])
            return val1

        # Invoke the compiled gate function func_id elementwise:
        # out_vec = func(in_vec). in_vec and out_vec may alias (in-place).
        def _call_func(func_id, in_vec, out_vec):
            param_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, func_id)
            state_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, func_id)

            llvm_func = ctx.import_llvm_function(getattr(self, func_id), tags=tags)
            builder.call(llvm_func, [param_ptr, state_ptr, in_vec, out_vec])

        # Calculate input gate: i_t = i_gate_func(W_i x_t + U_i h_prev)
        i_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'i_input_matrix')
        i_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'i_hidden_matrix')
        i_t = _mac(i_input_matrix, x_t, i_hidden_matrix, h_prev)
        _call_func("i_gate_func", i_t, i_t)

        # Calculate forget gate: f_t = f_gate_func(W_f x_t + U_f h_prev)
        f_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'f_input_matrix')
        f_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'f_hidden_matrix')
        f_t = _mac(f_input_matrix, x_t, f_hidden_matrix, h_prev)
        _call_func("f_gate_func", f_t, f_t)

        # Update cell state: c_t = f_t * c_prev + i_t * g_t
        g_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'g_input_matrix')
        g_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'g_hidden_matrix')
        g_t = _mac(g_input_matrix, x_t, g_hidden_matrix, h_prev)
        _call_func("g_gate_func", g_t, g_t)

        c_t = _mac(f_t, c_prev, i_t, g_t, mul_op=_vxv)

        # Calculate output gate: o_t = o_gate_func(W_o x_t + U_o h_prev)
        o_input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'o_input_matrix')
        o_hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, 'o_hidden_matrix')
        o_t = _mac(o_input_matrix, x_t, o_hidden_matrix, h_prev)
        _call_func("o_gate_func", o_t, o_t)

        # Update hidden state: h_t = o_t * h_gate_func(c_t)
        h_t = builder.alloca(h_prev.type.pointee)
        _call_func("h_gate_func", c_t, h_t)
        h_t = _vxv(o_t, h_t)

        # Writeback into value struct: arg_out = [h_t, c_t]
        builder.store(builder.load(h_t), builder.gep(arg_out, [ctx.int32_ty(0),
                                                               ctx.int32_ty(0)]))
        builder.store(builder.load(c_t), builder.gep(arg_out, [ctx.int32_ty(0),
                                                               ctx.int32_ty(1)]))

        return builder
# **********************************************************************************************************************
# LinearMatrix
# **********************************************************************************************************************
Expand Down
39 changes: 22 additions & 17 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6743,7 +6743,7 @@ def bfs(start):
pathways.append(p)
continue
for projection, efferent_node in [(p, p.receiver.owner) for p in curr_node.efferents]:
if (not hasattr(projection,'learnable')) or (projection.learnable is False):
if getattr(projection, 'learnable', False) is False or efferent_node in prev:
continue
prev[efferent_node] = projection
prev[projection] = curr_node
Expand Down Expand Up @@ -7506,6 +7506,23 @@ def evaluate(
else:
return net_outcome

def _infer_target_node(self, node):
if (NodeRole.TARGET not in self.get_roles_by_node(node) and NodeRole.LEARNING not in self.get_roles_by_node(node)):
node_efferent_mechanisms = [x.receiver.owner for x in node.efferents if x in self.projections]
comparators = [x for x in node_efferent_mechanisms
if (isinstance(x, ComparatorMechanism)
and NodeRole.LEARNING in self.get_roles_by_node(x))]
comparator_afferent_mechanisms = [x.sender.owner for c in comparators for x in c.afferents]
target_nodes = [t for t in comparator_afferent_mechanisms
if (NodeRole.TARGET in self.get_roles_by_node(t)
and NodeRole.LEARNING in self.get_roles_by_node(t))]

if len(target_nodes) != 1:
# Invalid specification: no valid target nodes or ambiguity in which target node to choose
raise Exception(f"Unable to infer learning target node from output node {node} of {self.name}")

if len(target_nodes) > 0:
return target_nodes[0]

def _infer_target_nodes(self, targets: dict):
"""
Expand All @@ -7519,22 +7536,10 @@ def _infer_target_nodes(self, targets: dict):
"""
ret = {}
for node, values in targets.items():
if (NodeRole.TARGET not in self.get_roles_by_node(node)
and NodeRole.LEARNING not in self.get_roles_by_node(node)):
node_efferent_mechanisms = [x.receiver.owner for x in node.efferents if x in self.projections]
comparators = [x for x in node_efferent_mechanisms
if (isinstance(x, ComparatorMechanism)
and NodeRole.LEARNING in self.get_roles_by_node(x))]
comparator_afferent_mechanisms = [x.sender.owner for c in comparators for x in c.afferents]
target_nodes = [t for t in comparator_afferent_mechanisms
if (NodeRole.TARGET in self.get_roles_by_node(t)
and NodeRole.LEARNING in self.get_roles_by_node(t))]

if len(target_nodes) != 1:
# Invalid specification: no valid target nodes or ambiguity in which target node to choose
raise Exception(f"Unable to infer learning target node from output node {node} of {self.name}")

ret[target_nodes[0]] = values
target_node = self._infer_target_node(node)

if target_node is not None:
ret[target_node] = values
else:
ret[node] = values
return ret
Expand Down
4 changes: 3 additions & 1 deletion psyneulink/core/globals/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
'LEARNING_PATHWAY', 'LEARNING_PROJECTION', 'LEARNING_PROJECTION_PARAMS', 'LEARNING_RATE', 'LEARNING_SIGNAL',
'LEARNING_SIGNAL_SPECS', 'LEARNING_SIGNALS',
'LESS_THAN', 'LESS_THAN_OR_EQUAL', 'LINEAR', 'LINEAR_COMBINATION_FUNCTION', 'LINEAR_FUNCTION',
'LINEAR_MATRIX_FUNCTION', 'LOG_ENTRIES', 'LOGISTIC_FUNCTION', 'LOW', 'LVOC_CONTROL_MECHANISM', 'L0', 'L1',
'LINEAR_MATRIX_FUNCTION', 'LOG_ENTRIES', 'LOGISTIC_FUNCTION', 'LOW', 'LSTM_FUNCTION', 'LVOC_CONTROL_MECHANISM', 'L0', 'L1',
'MAPPING_PROJECTION', 'MAPPING_PROJECTION_PARAMS', 'MASKED_MAPPING_PROJECTION',
'MATRIX', 'MATRIX_KEYWORD_NAMES', 'MATRIX_KEYWORD_SET', 'MATRIX_KEYWORD_VALUES', 'MATRIX_KEYWORDS','MatrixKeywords',
'MAX_ABS_DIFF', 'MAX_ABS_INDICATOR', 'MAX_ONE_HOT', 'MAX_ABS_ONE_HOT', 'MAX_ABS_VAL',
Expand Down Expand Up @@ -526,6 +526,7 @@ def _is_metric(metric):
TRANSFER_MECHANISM = "TransferMechanism"
LEABRA_MECHANISM = "LeabraMechanism"
RECURRENT_TRANSFER_MECHANISM = "RecurrentTransferMechanism"
LSTM_MECHANISM = "LSTMMechanism"
CONTRASTIVE_HEBBIAN_MECHANISM = "ContrastiveHebbianMechanism"
LCA_MECHANISM = "LCAMechanism"
KOHONEN_MECHANISM = 'KohonenMechanism'
Expand Down Expand Up @@ -557,6 +558,7 @@ def _is_metric(metric):
GAUSSIAN_FUNCTION = "Gaussian Function"
GAUSSIAN_DISTORT_FUNCTION = "GaussianDistort Function"
SOFTMAX_FUNCTION = 'SoftMax Function'
LSTM_FUNCTION = 'LSTM Function'
LINEAR_MATRIX_FUNCTION = "LinearMatrix Function"
TRANSFER_WITH_COSTS_FUNCTION = "TransferWithCosts Function"

Expand Down
2 changes: 2 additions & 0 deletions psyneulink/core/llvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def init_builtins():
with LLVMBuilderContext.get_global() as ctx:
builtins.setup_pnl_intrinsics(ctx)
builtins.setup_vxm(ctx)
builtins.setup_mxm(ctx)
builtins.setup_vxm_transposed(ctx)
builtins.setup_mersenne_twister(ctx)
builtins.setup_vec_add(ctx)
Expand All @@ -139,6 +140,7 @@ def init_builtins():
builtins.setup_vec_copy(ctx)
builtins.setup_vec_hadamard(ctx)
builtins.setup_mat_hadamard(ctx)
builtins.setup_vec_outer_product(ctx)
builtins.setup_vec_scalar_mult(ctx)
builtins.setup_mat_scalar_mult(ctx)
builtins.setup_mat_scalar_add(ctx)
Expand Down
Loading