Skip to content

Commit

Permalink
basic gmm worked
Browse files Browse the repository at this point in the history
  • Loading branch information
zxytim committed Dec 10, 2013
1 parent 65af7b0 commit 834d90c
Show file tree
Hide file tree
Showing 10 changed files with 657 additions and 52 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# $File: Makefile
# $Date: Mon Dec 09 00:40:18 2013 +0800
# $Date: Tue Dec 10 11:57:58 2013 +0800
#
# A single output portable Makefile for
# simple c++ project
Expand Down
32 changes: 32 additions & 0 deletions gmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/python2
# -*- coding: utf-8 -*-
# $File: gmm.py
# $Date: Tue Dec 10 11:36:41 2013 +0800
# $Author: Xinyu Zhou <zxytim[at]gmail[dot]com>

from sklearn.mixture import GMM


def read_data(fname):
    """Read a whitespace-separated numeric table from ``fname``.

    Every non-blank line becomes one row of floats.

    :param fname: path of the text file to read
    :return: list of rows, each row a list of floats
    :raises ValueError: if a token cannot be parsed as a float
    """
    rows = []
    with open(fname) as fin:
        for line in fin:
            fields = line.split()
            # Skip blank lines: they would otherwise produce an empty row
            # and make the sample matrix ragged.
            if fields:
                rows.append([float(tok) for tok in fields])
    # The list is fully built while the file is still open, so the result
    # does not rely on map() being eager (it is lazy on Python 3).
    return rows

def dump_gmm(gmm):
    """Print a fitted mixture model to stdout as plain text.

    Output layout, one item per line:
      * the number of components
      * the component weights, space separated
      * for each component: ``<dim> 1``, then the mean vector, then
        ``gmm.covars_[i]``, each space separated

    :param gmm: object exposing ``n_components``, ``weights_``, ``means_``
        and ``covars_``
    """
    def fmt(values):
        return " ".join(map(str, values))

    # Single-argument print(...) emits the same text on Python 2 and 3.
    print(gmm.n_components)
    print(fmt(gmm.weights_))
    for i in range(gmm.n_components):
        mean = gmm.means_[i]
        # The literal 1 is the second field of the per-component header line.
        print("%d 1" % len(mean))
        print(fmt(mean))
        # NOTE(review): covars_ presumably holds variances, not standard
        # deviations -- confirm that readers of this file expect that.
        print(fmt(gmm.covars_[i]))

def main():
    """Fit a 3-component mixture to ``test.data`` and print the model."""
    samples = read_data('test.data')
    model = GMM(3)
    model.fit(samples)
    dump_gmm(model)


if __name__ == '__main__':
    main()


# vim: foldmethod=marker
112 changes: 112 additions & 0 deletions plot-gmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#!/usr/bin/python2
# -*- coding: utf-8 -*-
# $File: plot-gmm.py
# $Date: Tue Dec 10 12:45:04 2013 +0800
# $Author: Xinyu Zhou <zxytim[at]gmail[dot]com>

import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
from matplotlib import cm
from scipy import stats, mgrid, c_, reshape, random, rot90
import argparse
from numpy import *
import numpy as np

class GassianTypeNotImplemented(Exception):
    """Signals a Gaussian whose covariance type the reader cannot handle."""

def get_args():
    """Parse the command line of the GMM plotting script.

    :return: ``argparse.Namespace`` with ``input`` (data file) and ``model``
        (model file); both options are mandatory.
    """
    parser = argparse.ArgumentParser(description='plot gmm')
    required_options = (
        ('-i', '--input', 'data file'),
        ('-m', '--model', 'model file'),
    )
    for short_flag, long_flag, help_text in required_options:
        parser.add_argument(short_flag, long_flag, help=help_text,
                            required=True)
    return parser.parse_args()


class Gaussian(object):
    """A single Gaussian component with per-dimension mean and sigma."""

    def __init__(self):
        # Covariance-type tag.
        self.covtype = 1
        # Number of dimensions.
        self.dim = 0
        # Per-dimension mean.
        self.mean = np.array([])
        # Per-dimension standard deviation (squared in probability_of).
        self.sigma = np.array([])
        # Placeholder for a full covariance matrix; not used by this class.
        self.covariance = np.array([[]])

    def probability_of(self, x):
        """Return the probability density of this Gaussian at point ``x``.

        The dimensions are treated as independent, so the joint density is
        the product of the one-dimensional normal densities
        ``exp(-(x - mean)**2 / (2 * sigma**2)) / (sqrt(2 * pi) * sigma)``.

        :param x: sequence of ``self.dim`` coordinates
        :return: the density as a float
        :raises AssertionError: if ``len(x)`` differs from ``self.dim``
        """
        assert len(x) == self.dim

        # mean/sigma may be plain lists, so convert before doing arithmetic.
        point = np.asarray(x, dtype=float)
        mean = np.asarray(self.mean, dtype=float)
        sigma = np.asarray(self.sigma, dtype=float)

        per_dim = (np.exp(-(point - mean) ** 2 / (2 * sigma ** 2))
                   / (np.sqrt(2 * np.pi) * sigma))
        return float(np.prod(per_dim))


class GMM(object):
    """Plain container for a Gaussian mixture; starts out empty."""

    def __init__(self):
        # Component objects, appended by whoever loads the model.
        self.gaussians = list()
        # Mixture weights.
        self.weights = array([])
        # Number of mixture components.
        self.nr_mixtures = 0

def read_data(fname):
    """Read a whitespace-separated numeric table and return it column-wise.

    :param fname: path of the data file, one sample per line
    :return: list of tuples of floats, one tuple per column; when rows have
        different lengths the result is truncated to the shortest row
    :raises ValueError: if a token cannot be parsed as a float
    """
    rows = []
    with open(fname) as fin:
        for line in fin:
            fields = line.split()
            # A blank line would add an empty row, and transposing with an
            # empty row yields no columns at all.
            if fields:
                rows.append([float(tok) for tok in fields])
    # list() so the result can be sliced on Python 3, where zip() is lazy.
    return list(zip(*rows))


def read_gaussian(fin):
    """Read one Gaussian component from an open model file.

    Consumes three lines from ``fin``: a ``<dim> <covtype>`` header, the
    mean vector and the sigma vector. Only covtype 1 is supported; for any
    other type just the header line is consumed before raising.

    :param fin: file object positioned at the component's header line
    :return: a populated ``Gaussian`` whose ``mean`` and ``sigma`` are lists
    :raises GassianTypeNotImplemented: if covtype is not 1
    :raises AssertionError: if a vector's length differs from ``dim``
    """
    def next_fields(convert):
        # Build a real list: len() is taken below and callers index into
        # it, neither of which works on a lazy Python 3 map object.
        return [convert(tok) for tok in fin.readline().split()]

    gaussian = Gaussian()
    gaussian.dim, gaussian.covtype = next_fields(int)
    if gaussian.covtype != 1:
        raise GassianTypeNotImplemented()

    gaussian.mean = next_fields(float)
    gaussian.sigma = next_fields(float)
    assert len(gaussian.mean) == gaussian.dim
    assert len(gaussian.sigma) == gaussian.dim
    return gaussian

def read_model(fname):
    """Load a mixture model from the text file ``fname``.

    The file starts with the number of mixtures on one line and the
    space-separated weights on the next, followed by one component record
    per mixture (parsed by ``read_gaussian``).

    :param fname: path of the model file
    :return: a populated ``GMM``
    """
    gmm = GMM()
    with open(fname) as fin:
        gmm.nr_mixtures = int(fin.readline().rstrip())
        # A list rather than map(): the latter is a one-shot iterator on
        # Python 3.
        gmm.weights = [float(tok) for tok in fin.readline().split()]
        for _ in range(gmm.nr_mixtures):
            gmm.gaussians.append(read_gaussian(fin))

    return gmm

def _axis_aligned_density(X, Y, gaussian):
    """Evaluate an uncorrelated bivariate normal density on a grid.

    The first two entries of ``gaussian.sigma`` are used as the standard
    deviations along x and y, and the first two entries of
    ``gaussian.mean`` as the centre.
    """
    sigma_x, sigma_y = gaussian.sigma[0], gaussian.sigma[1]
    mu_x, mu_y = gaussian.mean[0], gaussian.mean[1]
    z = ((X - mu_x) / sigma_x) ** 2 + ((Y - mu_y) / sigma_y) ** 2
    return np.exp(-z / 2.0) / (2 * np.pi * sigma_x * sigma_y)


def main():
    """Scatter the first two data columns and overlay the model contours.

    Reads the data and model files named on the command line, draws the
    points, a contour plot of the summed component densities and a yellow
    marker at every component mean, then shows the figure.
    """
    args = get_args()
    data = read_data(args.input)
    x, y = data[:2]
    gmm = read_model(args.model)

    fig = plt.figure()
    ax = fig.add_subplot(111, aspect='equal')
    ax.scatter(x, y)
    # Evaluate the density over exactly the area the scatter plot covers.
    x0, x1, y0, y1 = ax.axis()

    X, Y = np.meshgrid(np.linspace(x0, x1, 1000),
                       np.linspace(y0, y1, 1000))

    # The density is computed with numpy directly instead of
    # matplotlib.mlab.bivariate_normal, which newer Matplotlib releases no
    # longer provide.
    # NOTE(review): components are summed without applying gmm.weights --
    # confirm whether the unweighted sum is intended.
    Z = np.zeros_like(X)
    for gaussian in gmm.gaussians:
        Z += _axis_aligned_density(X, Y, gaussian)
    plt.contour(X, Y, Z, cmap=cm.PuBu_r)

    for gaussian in gmm.gaussians:
        print(gaussian.mean)
        plt.scatter(gaussian.mean[0], gaussian.mean[1], s=50, c='yellow')

    plt.show()


if __name__ == '__main__':
    main()


# vim: foldmethod=marker

212 changes: 212 additions & 0 deletions plot-point-3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#!/usr/bin/python2

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import argparse, sys

# Sentinel input name: main() reads from sys.stdin when args.input equals it.
stdin_fname = '$stdin$'

def get_args():
description = "plot points into graph. x and y seperated with white space in one line, or just y's"
parser = argparse.ArgumentParser(description = description)
parser.add_argument('-i', '--input',
help = 'input data file, "-" for stdin, default stdin',
default = '-')
parser.add_argument('-o', '--output',
help = 'output image', default = '')
parser.add_argument('--show',
help = 'show the figure after rendered',
action = 'store_true')
parser.add_argument('-t', '--title',
help = 'title of the graph',
default = '')
parser.add_argument('--xlabel',
help = 'x label',
default = 'x')
parser.add_argument('--ylabel',
help = 'y label',
default = 'y')
parser.add_argument('--zlabel',
help = 'z label',
default = 'z')
parser.add_argument('--xlim', help = 'xlim')
parser.add_argument('--ylim', help = 'ylim')
parser.add_argument('--zlim', help = 'zlim')

parser.add_argument('--annotate-maximum',
help = 'annonate maximum value in graph',
action = 'store_true')
parser.add_argument('--annotate-minimum',
help = 'annonate minimum value in graph',
action = 'store_true')
parser.add_argument('--xkcd',
help = 'xkcd style',
action = 'store_true')

args = parser.parse_args();

if (not args.show) and len(args.output) == 0:
raise Exception("at least one of --show and --output/-o must be specified")
if args.xlim:
args.xlim = map(float, args.xlim.rstrip().split(','))
if args.ylim:
args.ylim = map(float, args.ylim.rstrip().split(','))
if args.zlim:
args.zlim = map(float, args.zlim.rstrip().split(','))

return args


def filter_valid_range(points, rect):
    """Return the points lying inside ``rect``, borders included.

    rect = (min_x, max_x, min_y, max_y)

    When no point lies inside, a one-element list holding the first input
    point is returned instead.
    """
    inside = [(px, py) for px, py in points
              if rect[0] <= px <= rect[1] and rect[2] <= py <= rect[3]]
    if inside:
        return inside
    return [points[0]]

def _annotate_extremum(ax, label, point, vertical_sign, x_range, y_range,
                       rect):
    """Draw an arrow annotation ``label (x,y)`` pointing at ``point``.

    Four caption positions around the point are candidates; the first one
    inside ``rect`` is used. ``vertical_sign`` (+1 or -1) selects whether
    positions above or below the point are tried first.
    """
    px, py = point
    dx = 0.05 * x_range
    dy = vertical_sign * 0.025 * y_range
    candidates = [(px + dx, py + dy), (px - dx, py + dy),
                  (px + dx, py - dy), (px - dx, py - dy)]
    text_x, text_y = filter_valid_range(candidates, rect)[0]
    ax.annotate('{} ({:.3f},{:.3f})'.format(label, px, py),
                xy=(px, py),
                xytext=(text_x, text_y),
                arrowprops=dict(arrowstyle='->'))


def do_plot(data_x, data_y, data_z, args):
    """Render the points as a scatter plot and save and/or show the figure.

    :param data_x: x coordinates
    :param data_y: y coordinates
    :param data_z: z coordinates; a non-empty sequence selects a 3-D plot
    :param args: parsed command-line options (limits, labels, title,
        annotation switches, ``output`` and ``show``)
    """
    fig = plt.figure(figsize=(16.18, 10))
    is_3d = len(data_z) > 0
    # 'rectilinear' is Matplotlib's name for ordinary 2-D axes; '2d' is not
    # a registered projection.
    projection = '3d' if is_3d else 'rectilinear'
    ax = fig.add_axes((0.1, 0.2, 0.8, 0.7), projection=projection)
    if is_3d:
        ax.scatter(data_x, data_y, data_z)
    else:
        ax.scatter(data_x, data_y)

    if args.xlim:
        ax.set_xlim(args.xlim)
    if args.ylim:
        ax.set_ylim(args.ylim)
    # Only 3-D axes have a z limit.
    if args.zlim and is_3d:
        ax.set_zlim3d(args.zlim)
    if args.xlim or args.ylim or args.zlim:
        ax.set_aspect('equal')
    else:
        ax.set_aspect('equal', 'datalim')

    wants_annotation = args.annotate_maximum or args.annotate_minimum
    # Nothing can be annotated when there are no points.
    if wants_annotation and len(data_x) > 0:
        x_range = max(data_x) - min(data_x)
        y_range = max(data_y) - min(data_y)
        rect = ax.axis()
        indices = range(len(data_y))

        # max()/min() return the first extremal index, i.e. the earliest
        # point among ties.
        if args.annotate_maximum:
            top = max(indices, key=lambda i: data_y[i])
            _annotate_extremum(ax, 'maximum', (data_x[top], data_y[top]),
                               1, x_range, y_range, rect)
        if args.annotate_minimum:
            bottom = min(indices, key=lambda i: data_y[i])
            _annotate_extremum(ax, 'minimum',
                               (data_x[bottom], data_y[bottom]),
                               -1, x_range, y_range, rect)

    ax.set_xlabel(args.xlabel)
    ax.set_ylabel(args.ylabel)
    if is_3d:
        ax.set_zlabel(args.zlabel)

    fig.text(0.5, 0.05, args.title, ha='center')
    if args.output != '':
        plt.savefig(args.output)

    if args.show:
        plt.show()

def _parse_points(lines):
    """Parse text lines into ``(data_x, data_y, data_z)`` lists.

    Each non-blank line holds one, two or three numbers:
      * one number   -> y only; x is the zero-based index of the line
      * two numbers  -> x and y
      * three numbers -> x, y and z
    All non-blank lines must use the same layout. ``data_z`` stays empty
    unless the three-number layout is used.

    :raises RuntimeError: on a line with more than three numbers, or when
        the layout changes between lines
    """
    data_x, data_y, data_z = [], [], []
    data_format = None
    for lineno, raw in enumerate(lines):
        values = [float(tok) for tok in raw.split()]
        if not values:
            continue

        z = None
        if len(values) == 2:
            line_format = 0
            x, y = values
        elif len(values) == 1:
            line_format = 1
            x, y = lineno, values[0]
        elif len(values) == 3:
            line_format = 2
            x, y, z = values
        else:
            raise RuntimeError(
                'Can not parse input data at line {}'.format(lineno + 1))

        if data_format is None:
            data_format = line_format
        elif line_format != data_format:
            raise RuntimeError(
                'data format is not consistent, at line {}'.format(
                    lineno + 1))

        data_x.append(x)
        data_y.append(y)
        if z is not None:
            data_z.append(z)
    return data_x, data_y, data_z


def main():
    """Read points from the input file or stdin and plot them."""
    args = get_args()
    # The command-line default and documented spelling for stdin is "-";
    # the internal sentinel is accepted as well.
    use_stdin = args.input in ('-', stdin_fname)
    fin = sys.stdin if use_stdin else open(args.input)
    try:
        data_x, data_y, data_z = _parse_points(fin)
    finally:
        # Close the file even if parsing fails; never close stdin.
        if not use_stdin:
            fin.close()
    print(len(data_x))

    # NOTE(review): a single point is skipped while an empty input is still
    # plotted -- confirm whether "== 1" was meant to be "<= 1".
    if len(data_x) == 1:
        return

    if args.xkcd:
        with plt.xkcd():
            do_plot(data_x, data_y, data_z, args)
    else:
        do_plot(data_x, data_y, data_z, args)


if __name__ == '__main__':
    main()
Loading

0 comments on commit 834d90c

Please sign in to comment.