From 834d90ca8a5e613c28d12ceca04f07dc51e4cd21 Mon Sep 17 00:00:00 2001 From: Xinyu Zhou Date: Tue, 10 Dec 2013 12:51:46 +0800 Subject: [PATCH] basic gmm worked --- Makefile | 2 +- gmm.py | 32 +++++++ plot-gmm.py | 112 +++++++++++++++++++++++++ plot-point-3d.py | 212 +++++++++++++++++++++++++++++++++++++++++++++++ plot-point.py | 175 ++++++++++++++++++++++++++++++++++++++ src/.gmm.cc.swp | Bin 0 -> 20480 bytes src/gmm.cc | 133 ++++++++++++++++++++--------- src/gmm.hh | 7 +- src/main.cc | 33 +++++--- src/random.hh | 3 +- 10 files changed, 657 insertions(+), 52 deletions(-) create mode 100755 gmm.py create mode 100755 plot-gmm.py create mode 100755 plot-point-3d.py create mode 100755 plot-point.py create mode 100644 src/.gmm.cc.swp diff --git a/Makefile b/Makefile index 1d8a5dd..ae8876c 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # # $File: Makefile -# $Date: Mon Dec 09 00:40:18 2013 +0800 +# $Date: Tue Dec 10 11:57:58 2013 +0800 # # A single output portable Makefile for # simple c++ project diff --git a/gmm.py b/gmm.py new file mode 100755 index 0000000..87b8211 --- /dev/null +++ b/gmm.py @@ -0,0 +1,32 @@ +#!/usr/bin/python2 +# -*- coding: utf-8 -*- +# $File: gmm.py +# $Date: Tue Dec 10 11:36:41 2013 +0800 +# $Author: Xinyu Zhou + +from sklearn.mixture import GMM + + +def read_data(fname): + with open(fname) as fin: + return map(lambda line: map(float, line.rstrip().split()), fin) + +def dump_gmm(gmm): + print gmm.n_components + print " " . join(map(str, gmm.weights_)) + for i in range(gmm.n_components): + print len(gmm.means_[i]), 1 + print " " . join(map(str, gmm.means_[i])) + print " " . join(map(str, gmm.covars_[i])) + +def main(): + gmm = GMM(3) + X = read_data('test.data') + gmm.fit(X) + dump_gmm(gmm) + +if __name__ == '__main__': + main() + + +# vim: foldmethod=marker diff --git a/plot-gmm.py b/plot-gmm.py new file mode 100755 index 0000000..d6f1b5c --- /dev/null +++ b/plot-gmm.py @@ -0,0 +1,112 @@ +#!/usr/bin/python2 +# -*- coding: utf-8 -*- +# $File: plot-gmm.py +# $Date: Tue Dec 10 12:45:04 2013 +0800 +# $Author: Xinyu Zhou + +import matplotlib.pyplot as plt +import matplotlib.mlab as mlab +from matplotlib import cm +from scipy import stats, mgrid, c_, reshape, random, rot90 +import argparse +from numpy import * +import numpy as np + +class GassianTypeNotImplemented(Exception): + pass + +def get_args(): + description = 'plot gmm' + parser = argparse.ArgumentParser(description = description) + + parser.add_argument('-i', '--input', help = 'data file', required = True) + parser.add_argument('-m', '--model', help = 'model file', required = True) + + args = parser.parse_args() + + return args + + +class Gaussian(object): + def __init__(self): + self.covtype = 1 + self.dim = 0 + self.mean = array([]) + self.sigma = array([]) + self.covariance = array([[]]) + + def probability_of(self, x): + assert len(x) == self.dim + + return exp((x - mean)**2 / (2 * self.sigma**2)) / (2 * pi * self.sigma) + + +class GMM(object): + def __init__(self): + self.nr_mixtures = 0 + self.weights = array([]) + self.gaussians = [] + +def read_data(fname): + with open(fname) as fin: + return zip(*map( lambda line: map(float, line.rstrip().split()), fin)) + + +def read_gaussian(fin): + gaussian = Gaussian() + gaussian.dim, gaussian.covtype = map(int, fin.readline().rstrip().split()) + if gaussian.covtype == 1: + gaussian.mean = map(float, fin.readline().rstrip().split()) + gaussian.sigma = map(float, fin.readline().rstrip().split()) + assert len(gaussian.mean) == gaussian.dim + assert len(gaussian.sigma) == gaussian.dim + else: + raise GassianTypeNotImplemented() + return gaussian + +def read_model(fname): + gmm = GMM() + with open(fname) as fin: + gmm.nr_mixtures = int(fin.readline().rstrip()) + gmm.weights = map(float, fin.readline().rstrip().split()) + for i in range(gmm.nr_mixtures): + gmm.gaussians.append(read_gaussian(fin)) + + return gmm + +def main(): + args = get_args() + data = read_data(args.input) + x, y = data[:2] + gmm = read_model(args.model) + + fig = plt.figure() + ax = fig.add_subplot(111, aspect = 'equal') + ax.scatter(x, y) + x0, x1, y0, y1 = ax.axis() + + x = linspace(x0, x1, 1000) + y = linspace(y0, y1, 1000) + X, Y = meshgrid(x, y) + + def get_Z(X, Y, gaussian): + return mlab.bivariate_normal(X, Y, gaussian.sigma[0], gaussian.sigma[1], + gaussian.mean[0], gaussian.mean[1], 0) + + Z = get_Z(X, Y, gmm.gaussians[0]) + for gaussian in gmm.gaussians[1:]: + Z += get_Z(X, Y, gaussian) + plt.contour(X, Y, Z, cmap=cm.PuBu_r) + for gaussian in gmm.gaussians: + print gaussian.mean + plt.scatter(gaussian.mean[0], gaussian.mean[1], s = 50, c = 'yellow') + + plt.show() + + +if __name__ == '__main__': + main() + + +# vim: foldmethod=marker + diff --git a/plot-point-3d.py b/plot-point-3d.py new file mode 100755 index 0000000..87bf4ef --- /dev/null +++ b/plot-point-3d.py @@ -0,0 +1,212 @@ +#!/usr/bin/python2 + +import numpy as np +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +import argparse, sys + +stdin_fname = '$stdin$' + +def get_args(): + description = "plot points into graph. x and y seperated with white space in one line, or just y's" + parser = argparse.ArgumentParser(description = description) + parser.add_argument('-i', '--input', + help = 'input data file, "-" for stdin, default stdin', + default = '-') + parser.add_argument('-o', '--output', + help = 'output image', default = '') + parser.add_argument('--show', + help = 'show the figure after rendered', + action = 'store_true') + parser.add_argument('-t', '--title', + help = 'title of the graph', + default = '') + parser.add_argument('--xlabel', + help = 'x label', + default = 'x') + parser.add_argument('--ylabel', + help = 'y label', + default = 'y') + parser.add_argument('--zlabel', + help = 'z label', + default = 'z') + parser.add_argument('--xlim', help = 'xlim') + parser.add_argument('--ylim', help = 'ylim') + parser.add_argument('--zlim', help = 'zlim') + + parser.add_argument('--annotate-maximum', + help = 'annonate maximum value in graph', + action = 'store_true') + parser.add_argument('--annotate-minimum', + help = 'annonate minimum value in graph', + action = 'store_true') + parser.add_argument('--xkcd', + help = 'xkcd style', + action = 'store_true') + + args = parser.parse_args(); + + if (not args.show) and len(args.output) == 0: + raise Exception("at least one of --show and --output/-o must be specified") + if args.xlim: + args.xlim = map(float, args.xlim.rstrip().split(',')) + if args.ylim: + args.ylim = map(float, args.ylim.rstrip().split(',')) + if args.zlim: + args.zlim = map(float, args.zlim.rstrip().split(',')) + + return args + + +def filter_valid_range(points, rect): + """rect = (min_x, max_x, min_y, max_y)""" + ret = [] + for x, y in points: + if x >= rect[0] and x <= rect[1] and y >= rect[2] and y <= rect[3]: + ret.append((x, y)) + if len(ret) == 0: + ret.append(points[0]) + return ret + +def do_plot(data_x, data_y, data_z, args): + fig = plt.figure(figsize = (16.18, 10)) + projection = '2d' + if len(data_z) > 0: + projection = '3d' + ax = fig.add_axes((0.1, 0.2, 0.8, 0.7), projection = projection) + if projection == '2d': + ax.scatter(data_x, data_y) + else: + ax.scatter(data_x, data_y, data_z) + if args.xlim: + ax.set_xlim(args.xlim) + if args.ylim: + ax.set_ylim(args.ylim) + if args.zlim: + ax.set_zlim3d(args.zlim) + if args.xlim or args.ylim or args.zlim: + pass + ax.set_aspect('equal') + else: + ax.set_aspect('equal', 'datalim') + #ax.spines['right'].set_color('none') + #ax.spines['left'].set_color('none') + #plt.xticks([]) + #plt.yticks([]) + + if args.annotate_maximum or args.annotate_minimum: + max_x, min_x = max(data_x), min(data_x) + max_y, min_y = max(data_y), min(data_y) + x_range = max_x - min_x + y_range = max_y - min_y + x_max, y_max = data_y[0], data_y[0] + x_min, y_min = data_x[0], data_y[0] + + rect = ax.axis() + + for i in xrange(1, len(data_x)): + if data_y[i] > y_max: + y_max = data_y[i] + x_max = data_x[i] + if data_y[i] < y_min: + y_min = data_y[i] + x_min = data_x[i] + if args.annotate_maximum: + text_x, text_y = filter_valid_range([ + (x_max + 0.05 * x_range, + y_max + 0.025 * y_range), + (x_max - 0.05 * x_range, + y_max + 0.025 * y_range), + (x_max + 0.05 * x_range, + y_max - 0.025 * y_range), + (x_max - 0.05 * x_range, + y_max - 0.025 * y_range)], + rect)[0] + ax.annotate('maximum ({:.3f},{:.3f})' . format(x_max, y_max), + xy = (x_max, y_max), + xytext = (text_x, text_y), + arrowprops = dict(arrowstyle = '->')) + if args.annotate_minimum: + text_x, text_y = filter_valid_range([ + (x_min + 0.05 * x_range, + y_min - 0.025 * y_range), + (x_min - 0.05 * x_range, + y_min - 0.025 * y_range), + (x_min + 0.05 * x_range, + y_min + 0.025 * y_range), + (x_min - 0.05 * x_range, + y_min + 0.025 * y_range)], + rect)[0] + ax.annotate('minimum ({:.3f},{:.3f})' . format(x_min, y_min), + xy = (x_min, y_min), + xytext = (text_x, text_y), + arrowprops = dict(arrowstyle = '->')) + + ax.set_xlabel(args.xlabel) + ax.set_ylabel(args.ylabel) + if projection == '3d': + ax.set_zlabel(args.zlabel) + + fig.text(0.5, 0.05, args.title, ha = 'center') + if args.output != '': + plt.savefig(args.output) + + if args.show: + plt.show() + +def main(): + args = get_args() + if args.input == stdin_fname: + fin = sys.stdin + else: + fin = open(args.input) + + data_x = [] + data_y = [] + data_z = [] + data_format = -1 + for lineno, line in enumerate(fin.readlines()): + line = [float(i) for i in line.rstrip().split()] + line_data_format = -1 + x, y, z = None, None, None + if len(line) == 0: + continue + if len(line) == 2: + line_data_format = 0 + x, y = line + elif len(line) == 1: + line_data_format = 1 + x, y = lineno, line[0] + elif len(line) == 3: + x, y, z = line + line_data_format = 2; + else: + raise RuntimeError('Can not parse input data at line {}' . format(lineno + 1)) + + if data_format == -1: + data_format = line_data_format + else: + if line_data_format != data_format: + raise RuntimeError('data format is not consistent, at line {}' \ + . format(lineno + 1)) + data_x.append(x) + data_y.append(y) + if z != None: + data_z.append(z) + print len(data_x) + if args.input != stdin_fname: + fin.close() + + if len(data_x) == 1: + return + + if args.xkcd: + with plt.xkcd(): + do_plot(data_x, data_y, data_z, args) + else: + do_plot(data_x, data_y, data_z, args) + + + +if __name__ == '__main__': + main() diff --git a/plot-point.py b/plot-point.py new file mode 100755 index 0000000..a372a92 --- /dev/null +++ b/plot-point.py @@ -0,0 +1,175 @@ +#!/usr/bin/python2 + +import numpy as np +import matplotlib.pyplot as plt +import argparse, sys + +stdin_fname = '$stdin$' + +def get_args(): + description = "plot points into graph. x and y seperated with white space in one line, or just y's" + parser = argparse.ArgumentParser(description = description) + parser.add_argument('-i', '--input', + help = 'input data file, "-" for stdin, default stdin', + default = '-') + parser.add_argument('-o', '--output', + help = 'output image', default = '') + parser.add_argument('--show', + help = 'show the figure after rendered', + action = 'store_true') + parser.add_argument('-t', '--title', + help = 'title of the graph', + default = '') + parser.add_argument('--xlabel', + help = 'x label', + default = 'x') + parser.add_argument('--ylabel', + help = 'y label', + default = 'y') + parser.add_argument('--annotate-maximum', + help = 'annonate maximum value in graph', + action = 'store_true') + parser.add_argument('--annotate-minimum', + help = 'annonate minimum value in graph', + action = 'store_true') + parser.add_argument('--xkcd', + help = 'xkcd style', + action = 'store_true') + + args = parser.parse_args(); + + if (not args.show) and len(args.output) == 0: + raise Exception("at least one of --show and --output/-o must be specified") + + return args + + +def filter_valid_range(points, rect): + """rect = (min_x, max_x, min_y, max_y)""" + ret = [] + for x, y in points: + if x >= rect[0] and x <= rect[1] and y >= rect[2] and y <= rect[3]: + ret.append((x, y)) + if len(ret) == 0: + ret.append(points[0]) + return ret + +def do_plot(data_x, data_y, args): + fig = plt.figure(figsize = (16.18, 10)) + ax = fig.add_axes((0.1, 0.2, 0.8, 0.7)) + plt.scatter(data_x, data_y) +# ax.set_aspect('equal', 'datalim') + #ax.spines['right'].set_color('none') + #ax.spines['left'].set_color('none') + #plt.xticks([]) + #plt.yticks([]) + + if args.annotate_maximum or args.annotate_minimum: + max_x, min_x = max(data_x), min(data_x) + max_y, min_y = max(data_y), min(data_y) + x_range = max_x - min_x + y_range = max_y - min_y + x_max, y_max = data_y[0], data_y[0] + x_min, y_min = data_x[0], data_y[0] + + rect = ax.axis() + + for i in xrange(1, len(data_x)): + if data_y[i] > y_max: + y_max = data_y[i] + x_max = data_x[i] + if data_y[i] < y_min: + y_min = data_y[i] + x_min = data_x[i] + if args.annotate_maximum: + text_x, text_y = filter_valid_range([ + (x_max + 0.05 * x_range, + y_max + 0.025 * y_range), + (x_max - 0.05 * x_range, + y_max + 0.025 * y_range), + (x_max + 0.05 * x_range, + y_max - 0.025 * y_range), + (x_max - 0.05 * x_range, + y_max - 0.025 * y_range)], + rect)[0] + ax.annotate('maximum ({:.3f},{:.3f})' . format(x_max, y_max), + xy = (x_max, y_max), + xytext = (text_x, text_y), + arrowprops = dict(arrowstyle = '->')) + if args.annotate_minimum: + text_x, text_y = filter_valid_range([ + (x_min + 0.05 * x_range, + y_min - 0.025 * y_range), + (x_min - 0.05 * x_range, + y_min - 0.025 * y_range), + (x_min + 0.05 * x_range, + y_min + 0.025 * y_range), + (x_min - 0.05 * x_range, + y_min + 0.025 * y_range)], + rect)[0] + ax.annotate('minimum ({:.3f},{:.3f})' . format(x_min, y_min), + xy = (x_min, y_min), + xytext = (text_x, text_y), + arrowprops = dict(arrowstyle = '->')) + + plt.xlabel(args.xlabel) + plt.ylabel(args.ylabel) + + ax.grid(color = 'gray', linestyle = 'dashed') + + fig.text(0.5, 0.05, args.title, ha = 'center') + if args.output != '': + plt.savefig(args.output) + + if args.show: + plt.show() + +def main(): + args = get_args() + if args.input == stdin_fname: + fin = sys.stdin + else: + fin = open(args.input) + + data_x = [] + data_y = [] + data_format = -1 + for lineno, line in enumerate(fin.readlines()): + line = [float(i) for i in line.rstrip().split()] + line_data_format = -1 + if len(line) == 0: + continue + if len(line) == 2: + line_data_format = 0 + x, y = line + elif len(line) == 1: + line_data_format = 1 + x, y = lineno, line[0] + else: + raise RuntimeError('Can not parse input data at line {}' . format(lineno + 1)) + + if data_format == -1: + data_format = line_data_format + else: + if line_data_format != data_format: + raise RuntimeError('data format is not consistent, at line {}' \ + . format(lineno + 1)) + data_x.append(x) + data_y.append(y) + print len(data_x) + if args.input != stdin_fname: + fin.close() + + if len(data_x) == 1: + return + + if args.xkcd: + with plt.xkcd(): + do_plot(data_x, data_y, args) + else: + do_plot(data_x, data_y, args) + + + +if __name__ == '__main__': + main() diff --git a/src/.gmm.cc.swp b/src/.gmm.cc.swp new file mode 100644 index 0000000000000000000000000000000000000000..c1a19688d7c141a1f00df196d87e0ceab7047603 GIT binary patch literal 20480 zcmeI33y@@0d4R8inpP3T(k%?7#{q`!neKV45Z2zEF$=q{i@VEaSzusTn>*cidM?x5 zx9xkor%V>p)L{mmgpz{6a+*jY} zeE`8Fsh;B7?mqY2bI#cdPBoaWH=EObrz)RQ)#{P%9bFqq zU?hRZA%SkNT-tP!^|Fn#Gt$cGV=qcDeE3r+!l|$UUI;%wiMnt*ybk(Mf!#0( z4?Wwmeg!{-Pr=9FZ{VZwSMcX>6}$#6fk9S(f!{_%?hCJ_k3z8{ksd4QIo1 z;kVDUte?PL@Orowx^N|&3niF<)8IvLDm<98tS`U^;O+1>I0O~g4jbSk_>VubtZ%^Q z;8XBHI0QA=19PwmreG2ZkOvF?=^4l$z6f`~d*K?`1KVH{o(UHG>(h}X{2kl~AA$G6 zyWw4M9b5wE!k@q?@H}|t+f%m~Xp$Ge56PyY^BKh+zxEJn%o8d2@2R=AZg(a|I zBb*A4kQDkyxCg!p?}sYvf>*#6D8tL)LG({_@U|!s#ju9`>x<&+agSs#6ILwjK4-Id%Spm+Zf6@Ak^Jom+QYxM%Aw-I-j_vE7!_$%mg($YTd? z&NfykTDexI5{k{wr@qUj95UK(T2&Kqvt>(*j(4?nc)a=yEW=keZi!2ccqeD1aV7n`OccRi(v<&q!>xaur z&CSU$Q>!dof93KqEiU_Nl%K;QfEDp6PA5= zN)pKghL!wOf1WfGN#D{lIXIrwXySPdYc!L=ge5Y;M&R<1nP(|7GnAR{TDRFY^X*0R zO?%OV|DnzHd9ZeX+|cs6W+lw6kemgin%pMrPHA;vMenMnV@6 zWCo+1oJ^U_&6C0h&LQ!y*`k+d9})m#=tTwZ)vqU$lb*c$#iwvE~H_s$O$m->}UNoU+>Ab<{Ses%B@@?8fpr zGv#wOsf{zUFH@5{(HpopN1=75jMja z@LS^cZ^Pfio8d}009QZ>UIG^Ukof(}@CEohydSQCL$C*?pajo{pAoO$1s{So?12m5 zC2$)2F+4z={#CdUUI#Xu1E;`q;C^EAyWuu?C%hHj3~z!rfW-A@!;9e|;`009KKL$t z2kwTu;BVn(xCSnVZE!xE0jI;s@L$B{zlUGIPvK7ZIJ_H{;ZnE&cESeucj9x2<$nry z!u8OARk#3728rE24|l>Xa5H=iZh|+#Rj?0U0kg0Ho&isT6XEB~-*@4&a2vb_-UzRU zzl3XH0rrEQ>tVt5sB9#Gkpv#M1VSs`bMe~!gRQ{s&sE#)*;%#hHEK-<(_Wiv+MQLW zV@*#?=mbfwUYggj7tt23Qp)qMW38h!mIS#boLd{LM6A2!xb>C5r;4eO%FIMjp2>L; z;0=lGqyst_5chwu@k|fGcq0n&5UChrQZV2Q$B87L*cushP&!C+!fEO5jR&mI;B)p_7a&k=O1_k|uF}P#&w! ztd*t4nT=(gEOd+f(p?B68a+>P}|F^3ZsUd41C` zAC@_hTB!t8)Q4E$rB=7;blhsicsWsav|Mi5{XCr^ZKKv!@Gew-M)b7qkae;?cmGz*tB2=|HW(t!sFLbaJMR{?fN`d1CQ%$et zG?-4gujZ*ke4bu&mPwGRy<0E7XnVdui7XU1sKkrQNT=H1e`HJ*>XGfKtVUk1?pjKP z_4QmYy1bBNn$wGz9F=6i$S_w9R_bogX;u0NX=%+ROco1*v&i!%qphltRF71uA# zQ=d!xtLFi6KSqx{>b_+1W+dqgwc6`U3O;j*pI$`(|vb~;|HprSY_wTVQpBe&GF7R2n*PPi!+c4c18 z7{v>JlU6&7ppT|nG#$1f`tsNmXuXv)ldPu0ZuwnFQ$HHzOh&B7MFT1xb6j$aqdxjl z;A!F8Fm5{|jZ`+Ml`Kl`JM7rCu`bQ06lub&#Av2{gjAZB2c~u>#YX6vjTiEeCqJrL z(((Tb@OWP<`>w?Q(fR(b6YJj#_kisEUjzH#MX&)*gr~w&;G4wzpM~pS0WO9eupPF+ zBgFU*faC-|1D}AK;p4CfTj3YR7 zm6u?_@&lrcs?r|RI^(}1d|Tr^h2KK&D0(9$73TAWDSypw=k@z2s%|_Yd|xMUlpIJ% zjZ&gGi6xS{hxg;wIl(x6vW=fx6QjLWF%~=3_&}VO(-jo<_efD{Xi`mGExO39ESF6Z zo+BHT&_Lp_n;sZ%Md#=TF3SsKjTuO%^5qE?GKhXEshc4)&19AeS<*dc~uAzI8?&$jILIS3<@eCfRyGH%B z58cbifM8^x?+77r@&>L%{&C!W*@}}>G{)(OyyVP1`gyrrmqB)Mal0a$l&*Bks;5_0 zcgZ-}y`rkwfnCunL@MiLCaYI=kWzQ{Nr)Sdz9)K^n2vNMMwhfEjYYkTON&ZxM`asc zw?P%(5KT{JyJ}$AQ5_=ks@xWoz(gkt5OlOA z)*Hpa*fnhp)8Dg}QC4vn?8|u{F|9;aag@ceswmz<-I4=9i7{SZ7M%l9>n&nam60O54Fi`7#0?r+(NTX&7SwqIQ-ZWjp2wR`vl)QL_?(@w6eWawSyL z%&9P{AEi-L?=X$VN}27@(HqSz;RLca}E8Bg`GjA!^jQH|`_tC98$ z46me)_wYoShY#ZL>UFH6n>_BLlU4PF8NLgm6*Oo!Oj{LmSF-0EE+XA%QSH+w64bb<9I4shaU306ikCGcJA9ZIT=)=) zDj5~tYFL0#Kk8}NBfV9y?}L34Ur|4WgP2i_tIbvHyI=vMb-6=QJ!U?UhZt;aNqq=U(UXk */ @@ -59,7 +59,8 @@ real_t Gaussian::log_probability_of(std::vector &x) { for (int i = 0; i < dim; i ++) { real_t &s = sigma[i]; real_t s2 = s * s; - prob += -log(sqrt_2_pi * s) - 1.0 / (2 * s2) * (x[i] - mean[i]); + real_t d = (x[i] - mean[i]); + prob += -log(sqrt_2_pi * s) - 1.0 / (2 * s2) * d * d; } break; case COVTYPE_FULL: @@ -128,7 +129,8 @@ real_t Gaussian::probability_of(std::vector &x) { for (int i = 0; i < dim; i ++) { real_t &s = sigma[i]; real_t d = x[i] - mean[i]; - prob *= exp(- d * d / (2 * s * s)) / (sqrt_2_pi * s); + real_t p = exp(- d * d / (2 * s * s)) / (sqrt_2_pi * s); + prob *= p; } break; case COVTYPE_FULL: @@ -159,8 +161,24 @@ real_t GMM::log_probability_of(std::vector &x, int mixture_id) { real_t GMM::log_probability_of(std::vector &x) { real_t prob = 0; - for (auto &g: gaussians) - prob += g->log_probability_of(x); + for (int i = 0; i < nr_mixtures; i ++) { + prob += weights[i] * gaussians[i]->probability_of(x); + } + return log(prob); +} + +real_t GMM::probability_of(std::vector &x) { + real_t prob = 0; + for (int i = 0; i < nr_mixtures; i ++) { + prob *= weights[i] * gaussians[i]->probability_of(x); + } + return prob; +} + +real_t GMM::log_probability_of(std::vector> &X) { + real_t prob = 0; + for (auto &x: X) + prob += log_probability_of(x); return prob; } @@ -251,8 +269,16 @@ void GMMTrainerBaseline::init_gaussians(std::vector> &X) { gmm->weights.resize(gmm->nr_mixtures); for (auto &w: gmm->weights) w = random.rand_real(); + gmm->normalize_weights(); } +void GMM::normalize_weights() { + real_t w_sum = 0; + for (auto &w: weights) + w_sum += w; + for (auto &w: weights) + w /= w_sum; +} void GMMTrainerBaseline::clear_gaussians() { for (auto &g: gmm->gaussians) @@ -272,42 +298,59 @@ static void gassian_set_zero(Gaussian *gaussian) { } void GMMTrainerBaseline::iteration(std::vector> &X) { - size_t n = X.size(); - - std::vector mixture_density(gmm->nr_mixtures); - for (int i = 0; i < gmm->nr_mixtures; i ++) { - real_t &md = mixture_density[i] = 0; - std::vector mu(dim); - auto &prob = prob_of_y_given_x[i]; - for (size_t j = 0; j < n; j ++) { - prob[j] = gmm->gaussians[i]->probability_of(X[j]); - md += prob[j]; - } - gmm->weights[i] = md / n; + int n = (int)X.size(); + + for (int k = 0; k < gmm->nr_mixtures; k ++) + for (int i = 0; i < n; i ++) + prob_of_y_given_x[k][i] = gmm->weights[k] * gmm->gaussians[k]->probability_of(X[i]); + + for (int i = 0; i < n; i ++) { + real_t prob_sum = 0; + for (int k = 0; k < gmm->nr_mixtures; k ++) + prob_sum += prob_of_y_given_x[k][i]; + assert(prob_sum > 0); + for (int k = 0; k < gmm->nr_mixtures; k ++) + prob_of_y_given_x[k][i] /= prob_sum; + } + + for (int k = 0; k < gmm->nr_mixtures; k ++) { + N_k[k] = 0; + for (int i = 0; i < n; i ++) + N_k[k] += prob_of_y_given_x[k][i]; + assert(N_k[k] > 0); } - for (int i = 0; i < gmm->nr_mixtures; i ++) { - auto &gaussian = gmm->gaussians[i]; + for (auto &gaussian: gmm->gaussians) gassian_set_zero(gaussian); - auto &prob = prob_of_y_given_x[i]; - for (size_t j = 0; j < n; j ++) { - auto &x = X[j]; - for (int k = 0; k < dim; k ++) - gaussian->mean[k] += x[k] * prob[j]; + + for (int k = 0; k < gmm->nr_mixtures; k ++) + gmm->weights[k] = N_k[k] / n; + + vector tmp(dim); + for (int k = 0; k < gmm->nr_mixtures; k ++) { + auto &gaussian = gmm->gaussians[k]; + for (int i = 0; i < n; i ++) { + mult(X[i], prob_of_y_given_x[k][i], tmp); + add_self(gaussian->mean, tmp); } - mult_self(gaussian->mean, 1.0 / mixture_density[i]); + mult_self(gaussian->mean, 1.0 / N_k[k]); + } - for (size_t j = 0; j < n; j ++) { - auto &x = X[j]; - for (int k = 0; k < dim; k ++) { - real_t d = x[k] - gaussian->mean[k]; - gaussian->sigma[k] += d * d * prob[j]; - } + for (int k = 0; k < gmm->nr_mixtures; k ++) { + auto &gaussian = gmm->gaussians[k]; + for (int i = 0; i < n; i ++) { + sub(X[i], gaussian->mean, tmp); + for (auto &t: tmp) t = t * t; + mult_self(tmp, prob_of_y_given_x[k][i]); + add_self(gaussian->sigma, tmp); } - mult_self(gaussian->sigma, 1.0 / mixture_density[i]); - for (auto &s: gaussian->sigma) + mult_self(gaussian->sigma, 1.0 / N_k[k]); + for (auto &s: gaussian->sigma) { s = sqrt(s); + s = max(sqrt(min_covar), s); + } } + } void GMMTrainerBaseline::train(GMM *gmm, std::vector> &X) { @@ -325,18 +368,34 @@ void GMMTrainerBaseline::train(GMM *gmm, std::vector> &X) { for (auto &v: prob_of_y_given_x) v.resize(X.size()); + N_k.resize(gmm->nr_mixtures); + clear_gaussians(); init_gaussians(X); +#define PAUSE() \ + do { \ + ofstream out("gmm-test.model"); \ + gmm->dump(out); \ + gmm->dump(cout); \ + printf("press a key to continue ...\n"); \ + getchar(); \ + } while (0) + + real_t last_ll = -numeric_limits::max(); for (int i = 0; i < nr_iter; i ++) { iteration(X); // monitor average log likelihood - real_t ll = 0; - for (auto &x: X) - ll += gmm->log_probability_of(x); - ll /= X.size(); + real_t ll = gmm->log_probability_of(X); printf("iter %d: ll %lf\n", i, ll); + + real_t ll_diff = ll - last_ll; + if (fabs(ll_diff) / fabs(ll) < 1e-8 && ll_diff < 1e-8) { + printf("too small log likelihood increment, abort iteration.\n"); + break; + } + last_ll = ll; } } diff --git a/src/gmm.hh b/src/gmm.hh index 7845b1e..3184273 100644 --- a/src/gmm.hh +++ b/src/gmm.hh @@ -1,6 +1,6 @@ /* * $File: gmm.hh - * $Date: Mon Dec 09 00:45:03 2013 +0800 + * $Date: Tue Dec 10 12:42:53 2013 +0800 * $Author: Xinyu Zhou */ @@ -70,6 +70,7 @@ class GMMTrainerBaseline : public GMMTrainer { real_t min_covar; std::vector> prob_of_y_given_x; // y, x + std::vector N_k; }; class GMM { @@ -103,6 +104,10 @@ class GMM { real_t log_probability_of(std::vector &x, int mixture_id); real_t log_probability_of(std::vector &x); + real_t log_probability_of(std::vector> &X); + real_t probability_of(std::vector &x); + + void normalize_weights(); }; /** diff --git a/src/main.cc b/src/main.cc index 5452bee..16b6491 100644 --- a/src/main.cc +++ b/src/main.cc @@ -1,6 +1,6 @@ /* * $File: main.cc - * $Date: Mon Dec 09 00:42:17 2013 +0800 + * $Date: Tue Dec 10 12:50:34 2013 +0800 * $Author: Xinyu Zhou */ @@ -37,7 +37,7 @@ vector string_to_double_vector(string line) { while (end < len && line[end] != ' ' && line[end] != '\n') end ++; x.push_back(atof(line.substr(begin, end - begin).c_str())); - if (end == len || line[end] == '\n') + if (end == len - 1 || line[end] == '\n' || (end == len - 2 && line[end] == ' ' && line[end] == '\n')) break; begin = end + 1; end = begin; @@ -116,13 +116,7 @@ void fill_gaussian_2d(DenseDataset &X, Gaussian *gaussian, int nr_point) { X.push_back(gaussian->sample()); } -int main(int argc, char *argv[]) { -// srand(42); // Answer to The Ultimate Question of Life, the Universe, and Everything -// Args args = parse_args(argc, argv); - - DenseDataset X; -// read_dense_dataset(X, args.input_file.c_str()); - +void gen_gaussian_mixture(DenseDataset &X) { int nr_gaussian = 2; int nr_point_per_gaussian = 100; Gaussian g0(2); @@ -130,16 +124,31 @@ int main(int argc, char *argv[]) { g0.sigma = {0.1, 0.1}; Gaussian g1(2); - g1.mean = {1, 0}; + g1.mean = {1, 1}; g1.sigma = {0.1, 0.1}; + Gaussian g2(2); + g2.mean = {2, 1}; + g2.sigma = {0.2, 0.2}; + fill_gaussian_2d(X, &g0, nr_point_per_gaussian); fill_gaussian_2d(X, &g1, nr_point_per_gaussian); + fill_gaussian_2d(X, &g2, nr_point_per_gaussian); +} + +int main(int argc, char *argv[]) { + srand(42); // Answer to The Ultimate Question of Life, the Universe, and Everything +// Args args = parse_args(argc, argv); + + DenseDataset X; +// read_dense_dataset(X, "test.data"); + gen_gaussian_mixture(X); + int nr_mixture = 3; write_dense_dataset(X, "test.data"); - int nr_mixture = 2; - GMM gmm(nr_mixture); + GMMTrainerBaseline trainer(1000); + GMM gmm(nr_mixture, COVTYPE_DIAGONAL, &trainer); gmm.fit(X); ofstream fout("gmm-test.model"); diff --git a/src/random.hh b/src/random.hh index 9a53ff4..fdd9729 100644 --- a/src/random.hh +++ b/src/random.hh @@ -1,6 +1,6 @@ /* * $File: random.hh - * $Date: Fri Dec 06 14:29:49 2013 +0800 + * $Date: Tue Dec 10 12:44:17 2013 +0800 * $Author: Xinyu Zhou */ @@ -20,6 +20,7 @@ class Random { Random() { long long seed = std::chrono::system_clock::now().time_since_epoch().count(); +// seed = rand(); generator.seed(seed); }