-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
49 changed files
with
6,162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
function A = assocarray(keys, vals) | ||
% Make an associative array | ||
% % | ||
% keys{i} is the i'th string, vals{i} is the i'th value. | ||
% After construction, A('foo') will return the value associated with foo. | ||
|
||
A.keys = keys; | ||
A.vals = vals; | ||
A = class(A, 'assocarray'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
function val = subsref(A, S) | ||
% SUBSREF Subscript reference for an associative array | ||
% A('foo') will return the value associated with foo. | ||
% If there are multiple identicaly keys, the first match is returned. | ||
% Currently the search is sequential. | ||
|
||
i = 1; | ||
while i <= length(A.keys) | ||
if strcmp(S.subs{1}, A.keys{i}) | ||
val = A.vals{i}; | ||
return; | ||
end | ||
i = i + 1; | ||
end | ||
error(['can''t find ' S.subs{1}]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
function p = adjustable_CPD(CPD) | ||
% Check if this CPD have any adjustable params (gaussian) | ||
|
||
% % This function was adapted from Bayes Net Toolbox written by Kevin Murphy | ||
|
||
p = ~CPD.clamped_mean | ~CPD.clamped_cov | ~CPD.clamped_weights; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
function CPD = gaussian_CPD(bnet, self, varargin) | ||
% GAUSSIAN_CPD Make a conditional linear Gaussian distrib. | ||
% | ||
% CPD = gaussian_CPD(bnet, node, ...) will create a CPD with random parameters, | ||
% where node is the number of a node in this equivalence class. | ||
|
||
% To define this CPD precisely, call the continuous (cts) parents (if any) X, | ||
% the discrete parents (if any) Q, and this node Y. Then the distribution on Y is: | ||
% - no parents: Y ~ N(mu, Sigma) | ||
% - cts parents : Y|X=x ~ N(mu + W x, Sigma) | ||
% - discrete parents: Y|Q=i ~ N(mu(i), Sigma(i)) | ||
% - cts and discrete parents: Y|X=x,Q=i ~ N(mu(i) + W(i) x, Sigma(i)) | ||
% | ||
% The list below gives optional arguments [default value in brackets]. | ||
% (Let ns(i) be the size of node i, X = ns(X), Y = ns(Y) and Q = prod(ns(Q)).) | ||
% Parameters will be reshaped to the right size if necessary. | ||
% | ||
% mean - mu(:,i) is the mean given Q=i [ randn(Y,Q) ] | ||
% cov - Sigma(:,:,i) is the covariance given Q=i [ repmat(100*eye(Y,Y), [1 1 Q]) ] | ||
% weights - W(:,:,i) is the regression matrix given Q=i [ randn(Y,X,Q) ] | ||
% cov_type - if 'diag', Sigma(:,:,i) is diagonal [ 'full' ] | ||
% tied_cov - if 1, we constrain Sigma(:,:,i) to be the same for all i [0] | ||
% clamp_mean - if 1, we do not adjust mu(:,i) during learning [0] | ||
% clamp_cov - if 1, we do not adjust Sigma(:,:,i) during learning [0] | ||
% clamp_weights - if 1, we do not adjust W(:,:,i) during learning [0] | ||
% cov_prior_weight - weight given to I prior for estimating Sigma [0.01] | ||
% cov_prior_entropic - if 1, we also use an entropic prior for Sigma [0] | ||
% | ||
% e.g., CPD = gaussian_CPD(bnet, i, 'mean', [0; 0], 'clamp_mean', 1) | ||
|
||
% % This function was adapted from Bayes Net Toolbox written by Kevin Murphy | ||
|
||
if nargin==0 | ||
% This occurs if we are trying to load an object from a file. | ||
CPD = init_fields; | ||
clamp = 0; | ||
CPD = class(CPD, 'gaussian_CPD', generic_CPD(clamp)); | ||
return; | ||
elseif isa(bnet, 'gaussian_CPD') | ||
% This might occur if we are copying an object. | ||
CPD = bnet; | ||
return; | ||
end | ||
CPD = init_fields; | ||
|
||
CPD = class(CPD, 'gaussian_CPD', generic_CPD(0)); | ||
|
||
args = varargin; | ||
ns = bnet.node_sizes; | ||
ps = parents(bnet.dag, self); | ||
dps = myintersect(ps, bnet.dnodes); | ||
cps = myintersect(ps, bnet.cnodes); | ||
fam_sz = ns([ps self]); | ||
|
||
CPD.self = self; | ||
CPD.sizes = fam_sz; | ||
|
||
% Figure out which (if any) of the parents are discrete, and which cts, and how big they are | ||
% dps = discrete parents, cps = cts parents | ||
CPD.cps = find_equiv_posns(cps, ps); % cts parent index | ||
CPD.dps = find_equiv_posns(dps, ps); | ||
ss = fam_sz(end); | ||
psz = fam_sz(1:end-1); | ||
dpsz = prod(psz(CPD.dps)); | ||
cpsz = sum(psz(CPD.cps)); | ||
|
||
% set default params | ||
CPD.mean = randn(ss, dpsz); | ||
CPD.cov = 100*repmat(eye(ss), [1 1 dpsz]); | ||
CPD.weights = randn(ss, cpsz, dpsz); | ||
CPD.cov_type = 'full'; | ||
CPD.tied_cov = 0; | ||
CPD.clamped_mean = 0; | ||
CPD.clamped_cov = 0; | ||
CPD.clamped_weights = 0; | ||
CPD.cov_prior_weight = 0.01; | ||
CPD.cov_prior_entropic = 0; | ||
nargs = length(args); | ||
if nargs > 0 | ||
CPD = set_fields(CPD, args{:}); | ||
end | ||
|
||
% Make sure the matrices have 1 dimension per discrete parent. | ||
CPD.mean = myreshape(CPD.mean, [ss ns(dps)]); | ||
CPD.cov = myreshape(CPD.cov, [ss ss ns(dps)]); | ||
CPD.weights = myreshape(CPD.weights, [ss cpsz ns(dps)]); | ||
|
||
% Precompute indices into block structured matrices | ||
% to speed up CPD_to_lambda_msg and CPD_to_pi | ||
cpsizes = CPD.sizes(CPD.cps); | ||
CPD.cps_block_ndx = cell(1, length(cps)); | ||
for i=1:length(cps) | ||
CPD.cps_block_ndx{i} = block(i, cpsizes); | ||
end | ||
|
||
%%%%%%%%%%% | ||
% Learning stuff | ||
|
||
% expected sufficient statistics | ||
CPD.Wsum = zeros(dpsz,1); | ||
CPD.WYsum = zeros(ss, dpsz); | ||
CPD.WXsum = zeros(cpsz, dpsz); | ||
CPD.WYYsum = zeros(ss, ss, dpsz); | ||
CPD.WXXsum = zeros(cpsz, cpsz, dpsz); | ||
CPD.WXYsum = zeros(cpsz, ss, dpsz); | ||
|
||
% For BIC | ||
CPD.nsamples = 0; | ||
switch CPD.cov_type | ||
case 'full', | ||
% since symmetric | ||
%ncov_params = ss*(ss-1)/2; | ||
ncov_params = ss*(ss+1)/2; | ||
case 'diag', | ||
ncov_params = ss; | ||
otherwise | ||
error(['unrecognized cov_type ' cov_type]); | ||
end | ||
% params = weights + mean + cov | ||
if CPD.tied_cov | ||
CPD.nparams = ss*cpsz*dpsz + ss*dpsz + ncov_params; | ||
else | ||
CPD.nparams = ss*cpsz*dpsz + ss*dpsz + dpsz*ncov_params; | ||
end | ||
|
||
% for speeding up maximize_params | ||
CPD.useC = exist('rep_mult'); | ||
|
||
clamped = CPD.clamped_mean & CPD.clamped_cov & CPD.clamped_weights; | ||
CPD = set_clamped(CPD, clamped); | ||
|
||
%%%%%%%%%%% | ||
|
||
function CPD = init_fields() | ||
% This ensures we define the fields in the same order | ||
|
||
CPD.self = []; | ||
CPD.sizes = []; | ||
CPD.cps = []; | ||
CPD.dps = []; | ||
CPD.mean = []; | ||
CPD.cov = []; | ||
CPD.weights = []; | ||
CPD.clamped_mean = []; | ||
CPD.clamped_cov = []; | ||
CPD.clamped_weights = []; | ||
CPD.cov_type = []; | ||
CPD.tied_cov = []; | ||
CPD.Wsum = []; | ||
CPD.WYsum = []; | ||
CPD.WXsum = []; | ||
CPD.WYYsum = []; | ||
CPD.WXXsum = []; | ||
CPD.WXYsum = []; | ||
CPD.nsamples = []; | ||
CPD.nparams = []; | ||
CPD.cov_prior_weight = []; | ||
CPD.cov_prior_entropic = []; | ||
CPD.useC = []; | ||
CPD.cps_block_ndx = []; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
function CPD = learn_params(CPD, fam, data, ns, cnodes) | ||
|
||
% Compute the maximum likelihood estimate of the params of a gaussian CPD given complete data | ||
% | ||
% | ||
% data(i,m) is the value of node i in case m (can be cell array). | ||
% % This function was adapted from Bayes Net Toolbox written by Kevin Murphy | ||
|
||
ncases = size(data, 2); | ||
CPD = reset_ess(CPD); | ||
% make a fully observed joint distribution over the family | ||
fmarginal.domain = fam; | ||
fmarginal.T = 1; | ||
fmarginal.mu = []; | ||
fmarginal.Sigma = []; | ||
|
||
hidden_bitv = zeros(1, max(fam)); | ||
|
||
for m=1:ncases | ||
% specify (as a bit vector) which elements in the family domain are hidden | ||
hidden_bitv = zeros(1, max(fmarginal.domain)); | ||
ev = data(:,m); | ||
hidden_bitv(isempty(ev))=1; | ||
CPD = update_ess(CPD, fmarginal, ev, ns, cnodes, hidden_bitv); | ||
end | ||
CPD = maximize_params(CPD); | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
function CPD = maximize_params(CPD) | ||
% Set the params of a CPD to their ML values (Gaussian) | ||
% | ||
% This function was adapted from Bayes Net Toolbox written by Kevin Murphy | ||
|
||
if ~adjustable_CPD(CPD), return; end | ||
|
||
|
||
if CPD.clamped_mean | ||
cl_mean = CPD.mean; | ||
else | ||
cl_mean = []; | ||
end | ||
|
||
if CPD.clamped_cov | ||
cl_cov = CPD.cov; | ||
else | ||
cl_cov = []; | ||
end | ||
|
||
if CPD.clamped_weights | ||
cl_weights = CPD.weights; | ||
else | ||
cl_weights = []; | ||
end | ||
|
||
[ssz psz Q] = size(CPD.weights); | ||
|
||
[ss cpsz dpsz] = size(CPD.weights); % ss = self size = ssz | ||
if cpsz > CPD.nsamples | ||
fprintf('gaussian_CPD/maximize_params: warning: input dimension (%d) > nsamples (%d)\n', ... | ||
cpsz, CPD.nsamples); | ||
end | ||
|
||
prior = repmat(CPD.cov_prior_weight*eye(ssz,ssz), [1 1 Q]); | ||
|
||
|
||
[CPD.mean, CPD.cov, CPD.weights] = ... | ||
clg_Mstep(CPD.Wsum, CPD.WYsum, CPD.WYYsum, [], CPD.WXsum, CPD.WXXsum, CPD.WXYsum, ... | ||
'cov_type', CPD.cov_type, 'clamped_mean', cl_mean, ... | ||
'clamped_cov', cl_cov, 'clamped_weights', cl_weights, ... | ||
'tied_cov', CPD.tied_cov, ... | ||
'cov_prior', prior); | ||
|
||
if 0 | ||
CPD.mean = reshape(CPD.mean, [ss dpsz]); | ||
CPD.cov = reshape(CPD.cov, [ss ss dpsz]); | ||
CPD.weights = reshape(CPD.weights, [ss cpsz dpsz]); | ||
end | ||
|
||
sz = CPD.sizes; | ||
ss = sz(end); | ||
|
||
cpsz = sum(sz(CPD.cps)); | ||
|
||
dpsz = sz(CPD.dps); | ||
CPD.mean = myreshape(CPD.mean, [ss dpsz]); | ||
CPD.cov = myreshape(CPD.cov, [ss ss dpsz]); | ||
CPD.weights = myreshape(CPD.weights, [ss cpsz dpsz]); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
function CPD = reset_ess(CPD) | ||
% Reset the Expected Sufficient Statistics for a Gaussian CPD. | ||
% % This function was adapted from Bayes Net Toolbox written by Kevin Murphy | ||
|
||
CPD.nsamples = 0; | ||
CPD.Wsum = zeros(size(CPD.Wsum)); | ||
CPD.WYsum = zeros(size(CPD.WYsum)); | ||
CPD.WYYsum = zeros(size(CPD.WYYsum)); | ||
CPD.WXsum = zeros(size(CPD.WXsum)); | ||
CPD.WXXsum = zeros(size(CPD.WXXsum)); | ||
CPD.WXYsum = zeros(size(CPD.WXYsum)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
function CPD = set_fields(CPD, varargin) | ||
% Set the parameters (fields) for a gaussian_CPD object | ||
% | ||
% mean - mu(:,i) is the mean given Q=i | ||
% cov - Sigma(:,:,i) is the covariance given Q=i | ||
% weights - W(:,:,i) is the regression matrix given Q=i | ||
% cov_type - if 'diag', Sigma(:,:,i) is diagonal | ||
% tied_cov - if 1, we constrain Sigma(:,:,i) to be the same for all i | ||
% clamp_mean - if 1, we do not adjust mu(:,i) during learning | ||
% clamp_cov - if 1, we do not adjust Sigma(:,:,i) during learning | ||
% clamp_weights - if 1, we do not adjust W(:,:,i) during learning | ||
% clamp - if 1, we do not adjust any params | ||
% cov_prior_weight - weight given to I prior for estimating Sigma | ||
% cov_prior_entropic - if 1, we also use an entropic prior for Sigma [0] | ||
% | ||
% e.g., CPD = set_params(CPD, 'mean', [0;0]) | ||
% This function was adapted from Bayes Net Toolbox written by Kevin Murphy | ||
|
||
args = varargin; | ||
nargs = length(args); | ||
for i=1:2:nargs | ||
switch args{i}, | ||
case 'mean', CPD.mean = args{i+1}; | ||
case 'cov', CPD.cov = args{i+1}; | ||
case 'weights', CPD.weights = args{i+1}; | ||
case 'cov_type', CPD.cov_type = args{i+1}; | ||
%case 'tied_cov', CPD.tied_cov = strcmp(args{i+1}, 'yes'); | ||
case 'tied_cov', CPD.tied_cov = args{i+1}; | ||
case 'clamp_mean', CPD.clamped_mean = args{i+1}; | ||
case 'clamp_cov', CPD.clamped_cov = args{i+1}; | ||
case 'clamp_weights', CPD.clamped_weights = args{i+1}; | ||
case 'clamp', clamp = args{i+1}; | ||
CPD.clamped_mean = clamp; | ||
CPD.clamped_cov = clamp; | ||
CPD.clamped_weights = clamp; | ||
case 'cov_prior_weight', CPD.cov_prior_weight = args{i+1}; | ||
case 'cov_prior_entropic', CPD.cov_prior_entropic = args{i+1}; | ||
otherwise, | ||
error(['invalid argument name ' args{i}]); | ||
end | ||
end |
Oops, something went wrong.