-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
657 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/usr/bin/python2 | ||
# -*- coding: utf-8 -*- | ||
# $File: gmm.py | ||
# $Date: Tue Dec 10 11:36:41 2013 +0800 | ||
# $Author: Xinyu Zhou <zxytim[at]gmail[dot]com> | ||
|
||
from sklearn.mixture import GMM | ||
|
||
|
||
def read_data(fname): | ||
with open(fname) as fin: | ||
return map(lambda line: map(float, line.rstrip().split()), fin) | ||
|
||
def dump_gmm(gmm): | ||
print gmm.n_components | ||
print " " . join(map(str, gmm.weights_)) | ||
for i in range(gmm.n_components): | ||
print len(gmm.means_[i]), 1 | ||
print " " . join(map(str, gmm.means_[i])) | ||
print " " . join(map(str, gmm.covars_[i])) | ||
|
||
def main(): | ||
gmm = GMM(3) | ||
X = read_data('test.data') | ||
gmm.fit(X) | ||
dump_gmm(gmm) | ||
|
||
if __name__ == '__main__': | ||
main() | ||
|
||
|
||
# vim: foldmethod=marker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
#!/usr/bin/python2 | ||
# -*- coding: utf-8 -*- | ||
# $File: plot-gmm.py | ||
# $Date: Tue Dec 10 12:45:04 2013 +0800 | ||
# $Author: Xinyu Zhou <zxytim[at]gmail[dot]com> | ||
|
||
import matplotlib.pyplot as plt | ||
import matplotlib.mlab as mlab | ||
from matplotlib import cm | ||
from scipy import stats, mgrid, c_, reshape, random, rot90 | ||
import argparse | ||
from numpy import * | ||
import numpy as np | ||
|
||
class GassianTypeNotImplemented(Exception): | ||
pass | ||
|
||
def get_args(): | ||
description = 'plot gmm' | ||
parser = argparse.ArgumentParser(description = description) | ||
|
||
parser.add_argument('-i', '--input', help = 'data file', required = True) | ||
parser.add_argument('-m', '--model', help = 'model file', required = True) | ||
|
||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
class Gaussian(object): | ||
def __init__(self): | ||
self.covtype = 1 | ||
self.dim = 0 | ||
self.mean = array([]) | ||
self.sigma = array([]) | ||
self.covariance = array([[]]) | ||
|
||
def probability_of(self, x): | ||
assert len(x) == self.dim | ||
|
||
return exp((x - mean)**2 / (2 * self.sigma**2)) / (2 * pi * self.sigma) | ||
|
||
|
||
class GMM(object): | ||
def __init__(self): | ||
self.nr_mixtures = 0 | ||
self.weights = array([]) | ||
self.gaussians = [] | ||
|
||
def read_data(fname): | ||
with open(fname) as fin: | ||
return zip(*map( lambda line: map(float, line.rstrip().split()), fin)) | ||
|
||
|
||
def read_gaussian(fin): | ||
gaussian = Gaussian() | ||
gaussian.dim, gaussian.covtype = map(int, fin.readline().rstrip().split()) | ||
if gaussian.covtype == 1: | ||
gaussian.mean = map(float, fin.readline().rstrip().split()) | ||
gaussian.sigma = map(float, fin.readline().rstrip().split()) | ||
assert len(gaussian.mean) == gaussian.dim | ||
assert len(gaussian.sigma) == gaussian.dim | ||
else: | ||
raise GassianTypeNotImplemented() | ||
return gaussian | ||
|
||
def read_model(fname): | ||
gmm = GMM() | ||
with open(fname) as fin: | ||
gmm.nr_mixtures = int(fin.readline().rstrip()) | ||
gmm.weights = map(float, fin.readline().rstrip().split()) | ||
for i in range(gmm.nr_mixtures): | ||
gmm.gaussians.append(read_gaussian(fin)) | ||
|
||
return gmm | ||
|
||
def main(): | ||
args = get_args() | ||
data = read_data(args.input) | ||
x, y = data[:2] | ||
gmm = read_model(args.model) | ||
|
||
fig = plt.figure() | ||
ax = fig.add_subplot(111, aspect = 'equal') | ||
ax.scatter(x, y) | ||
x0, x1, y0, y1 = ax.axis() | ||
|
||
x = linspace(x0, x1, 1000) | ||
y = linspace(y0, y1, 1000) | ||
X, Y = meshgrid(x, y) | ||
|
||
def get_Z(X, Y, gaussian): | ||
return mlab.bivariate_normal(X, Y, gaussian.sigma[0], gaussian.sigma[1], | ||
gaussian.mean[0], gaussian.mean[1], 0) | ||
|
||
Z = get_Z(X, Y, gmm.gaussians[0]) | ||
for gaussian in gmm.gaussians[1:]: | ||
Z += get_Z(X, Y, gaussian) | ||
plt.contour(X, Y, Z, cmap=cm.PuBu_r) | ||
for gaussian in gmm.gaussians: | ||
print gaussian.mean | ||
plt.scatter(gaussian.mean[0], gaussian.mean[1], s = 50, c = 'yellow') | ||
|
||
plt.show() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() | ||
|
||
|
||
# vim: foldmethod=marker | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
#!/usr/bin/python2 | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from mpl_toolkits.mplot3d import Axes3D | ||
import argparse, sys | ||
|
||
stdin_fname = '$stdin$' | ||
|
||
def get_args(): | ||
description = "plot points into graph. x and y seperated with white space in one line, or just y's" | ||
parser = argparse.ArgumentParser(description = description) | ||
parser.add_argument('-i', '--input', | ||
help = 'input data file, "-" for stdin, default stdin', | ||
default = '-') | ||
parser.add_argument('-o', '--output', | ||
help = 'output image', default = '') | ||
parser.add_argument('--show', | ||
help = 'show the figure after rendered', | ||
action = 'store_true') | ||
parser.add_argument('-t', '--title', | ||
help = 'title of the graph', | ||
default = '') | ||
parser.add_argument('--xlabel', | ||
help = 'x label', | ||
default = 'x') | ||
parser.add_argument('--ylabel', | ||
help = 'y label', | ||
default = 'y') | ||
parser.add_argument('--zlabel', | ||
help = 'z label', | ||
default = 'z') | ||
parser.add_argument('--xlim', help = 'xlim') | ||
parser.add_argument('--ylim', help = 'ylim') | ||
parser.add_argument('--zlim', help = 'zlim') | ||
|
||
parser.add_argument('--annotate-maximum', | ||
help = 'annonate maximum value in graph', | ||
action = 'store_true') | ||
parser.add_argument('--annotate-minimum', | ||
help = 'annonate minimum value in graph', | ||
action = 'store_true') | ||
parser.add_argument('--xkcd', | ||
help = 'xkcd style', | ||
action = 'store_true') | ||
|
||
args = parser.parse_args(); | ||
|
||
if (not args.show) and len(args.output) == 0: | ||
raise Exception("at least one of --show and --output/-o must be specified") | ||
if args.xlim: | ||
args.xlim = map(float, args.xlim.rstrip().split(',')) | ||
if args.ylim: | ||
args.ylim = map(float, args.ylim.rstrip().split(',')) | ||
if args.zlim: | ||
args.zlim = map(float, args.zlim.rstrip().split(',')) | ||
|
||
return args | ||
|
||
|
||
def filter_valid_range(points, rect): | ||
"""rect = (min_x, max_x, min_y, max_y)""" | ||
ret = [] | ||
for x, y in points: | ||
if x >= rect[0] and x <= rect[1] and y >= rect[2] and y <= rect[3]: | ||
ret.append((x, y)) | ||
if len(ret) == 0: | ||
ret.append(points[0]) | ||
return ret | ||
|
||
def do_plot(data_x, data_y, data_z, args): | ||
fig = plt.figure(figsize = (16.18, 10)) | ||
projection = '2d' | ||
if len(data_z) > 0: | ||
projection = '3d' | ||
ax = fig.add_axes((0.1, 0.2, 0.8, 0.7), projection = projection) | ||
if projection == '2d': | ||
ax.scatter(data_x, data_y) | ||
else: | ||
ax.scatter(data_x, data_y, data_z) | ||
if args.xlim: | ||
ax.set_xlim(args.xlim) | ||
if args.ylim: | ||
ax.set_ylim(args.ylim) | ||
if args.zlim: | ||
ax.set_zlim3d(args.zlim) | ||
if args.xlim or args.ylim or args.zlim: | ||
pass | ||
ax.set_aspect('equal') | ||
else: | ||
ax.set_aspect('equal', 'datalim') | ||
#ax.spines['right'].set_color('none') | ||
#ax.spines['left'].set_color('none') | ||
#plt.xticks([]) | ||
#plt.yticks([]) | ||
|
||
if args.annotate_maximum or args.annotate_minimum: | ||
max_x, min_x = max(data_x), min(data_x) | ||
max_y, min_y = max(data_y), min(data_y) | ||
x_range = max_x - min_x | ||
y_range = max_y - min_y | ||
x_max, y_max = data_y[0], data_y[0] | ||
x_min, y_min = data_x[0], data_y[0] | ||
|
||
rect = ax.axis() | ||
|
||
for i in xrange(1, len(data_x)): | ||
if data_y[i] > y_max: | ||
y_max = data_y[i] | ||
x_max = data_x[i] | ||
if data_y[i] < y_min: | ||
y_min = data_y[i] | ||
x_min = data_x[i] | ||
if args.annotate_maximum: | ||
text_x, text_y = filter_valid_range([ | ||
(x_max + 0.05 * x_range, | ||
y_max + 0.025 * y_range), | ||
(x_max - 0.05 * x_range, | ||
y_max + 0.025 * y_range), | ||
(x_max + 0.05 * x_range, | ||
y_max - 0.025 * y_range), | ||
(x_max - 0.05 * x_range, | ||
y_max - 0.025 * y_range)], | ||
rect)[0] | ||
ax.annotate('maximum ({:.3f},{:.3f})' . format(x_max, y_max), | ||
xy = (x_max, y_max), | ||
xytext = (text_x, text_y), | ||
arrowprops = dict(arrowstyle = '->')) | ||
if args.annotate_minimum: | ||
text_x, text_y = filter_valid_range([ | ||
(x_min + 0.05 * x_range, | ||
y_min - 0.025 * y_range), | ||
(x_min - 0.05 * x_range, | ||
y_min - 0.025 * y_range), | ||
(x_min + 0.05 * x_range, | ||
y_min + 0.025 * y_range), | ||
(x_min - 0.05 * x_range, | ||
y_min + 0.025 * y_range)], | ||
rect)[0] | ||
ax.annotate('minimum ({:.3f},{:.3f})' . format(x_min, y_min), | ||
xy = (x_min, y_min), | ||
xytext = (text_x, text_y), | ||
arrowprops = dict(arrowstyle = '->')) | ||
|
||
ax.set_xlabel(args.xlabel) | ||
ax.set_ylabel(args.ylabel) | ||
if projection == '3d': | ||
ax.set_zlabel(args.zlabel) | ||
|
||
fig.text(0.5, 0.05, args.title, ha = 'center') | ||
if args.output != '': | ||
plt.savefig(args.output) | ||
|
||
if args.show: | ||
plt.show() | ||
|
||
def main(): | ||
args = get_args() | ||
if args.input == stdin_fname: | ||
fin = sys.stdin | ||
else: | ||
fin = open(args.input) | ||
|
||
data_x = [] | ||
data_y = [] | ||
data_z = [] | ||
data_format = -1 | ||
for lineno, line in enumerate(fin.readlines()): | ||
line = [float(i) for i in line.rstrip().split()] | ||
line_data_format = -1 | ||
x, y, z = None, None, None | ||
if len(line) == 0: | ||
continue | ||
if len(line) == 2: | ||
line_data_format = 0 | ||
x, y = line | ||
elif len(line) == 1: | ||
line_data_format = 1 | ||
x, y = lineno, line[0] | ||
elif len(line) == 3: | ||
x, y, z = line | ||
line_data_format = 2; | ||
else: | ||
raise RuntimeError('Can not parse input data at line {}' . format(lineno + 1)) | ||
|
||
if data_format == -1: | ||
data_format = line_data_format | ||
else: | ||
if line_data_format != data_format: | ||
raise RuntimeError('data format is not consistent, at line {}' \ | ||
. format(lineno + 1)) | ||
data_x.append(x) | ||
data_y.append(y) | ||
if z != None: | ||
data_z.append(z) | ||
print len(data_x) | ||
if args.input != stdin_fname: | ||
fin.close() | ||
|
||
if len(data_x) == 1: | ||
return | ||
|
||
if args.xkcd: | ||
with plt.xkcd(): | ||
do_plot(data_x, data_y, data_z, args) | ||
else: | ||
do_plot(data_x, data_y, data_z, args) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.