Sunday, January 20, 2008

Mixture of Gaussians and Expectation Maximization algorithm

Recently I was implementing the expectation-maximization (EM) algorithm to solve a fuzzy data clustering problem. Consider a dataset consisting of n points (x,y):

d=((x1,y1), (x2,y2), ..., (xn,yn))

In non-fuzzy (hard) clustering each point is assigned to exactly one cluster. Fuzzy clustering can assign points to more than one cluster by giving membership grades, which indicate the degree to which the data points belong to the different clusters.
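
To make the difference concrete, here is a minimal sketch. The class name `Memberships` and the method `hardAssign` are made up for this illustration and are not part of the program further down. Given the membership grades of one point, a hard clustering keeps only the index of the largest grade, while a fuzzy clustering keeps the whole row of grades.

```java
final class Memberships {
    // Returns the index of the largest membership grade, i.e. the single
    // cluster a hard clustering would pick for this point.
    static int hardAssign(double[] grades) {
        int best = 0;
        for (int j = 1; j < grades.length; j++) {
            if (grades[j] > grades[best]) best = j;
        }
        return best;
    }

    public static void main(String[] args) {
        // Fuzzy view: one grade per cluster, grades sum to 1 (example values).
        double[] grades = {0.15, 0.70, 0.15};
        System.out.println("fuzzy grades: " + java.util.Arrays.toString(grades));
        System.out.println("hard cluster: " + hardAssign(grades));
    }
}
```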

I was solving a problem where the points d are assumed to be generated by a mixture of K Gaussian distributions. In other words, the distribution of the points in d is given by the density:

p(x,y)= w1 p1(x,y)+ ... + wK pK(x,y)

where w1+...+wK=1 and

pj(x,y) = 1/(2 π sxj syj) exp{-½[(x-mxj)/sxj]^2 - ½[(y-myj)/syj]^2}

Note that each dimension has its own standard deviation: sx for x and sy for y. Some formulations use a single standard deviation for both dimensions, which results in clusters shaped like circles (instead of axis-aligned ellipses).
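
The two variants can be sketched side by side. This is only an illustration of the density formula above; the class `GaussianDensity` and its method names are assumptions made for this example and do not appear in the program below.

```java
final class GaussianDensity {
    // Density of one component with a separate standard deviation per axis
    // (elliptical cluster): 1/(2 pi sx sy) exp{-zx^2/2 - zy^2/2}.
    static double density(double x, double y,
                          double mx, double my, double sx, double sy) {
        double zx = (x - mx) / sx;
        double zy = (y - my) / sy;
        return Math.exp(-0.5 * zx * zx - 0.5 * zy * zy) / (2 * Math.PI * sx * sy);
    }

    // Same formula with one shared standard deviation (circular cluster).
    static double circularDensity(double x, double y,
                                  double mx, double my, double s) {
        return density(x, y, mx, my, s, s);
    }

    public static void main(String[] args) {
        // With sx > sy the density falls off more slowly along x than along y.
        System.out.println(density(1.0, 0.0, 0.0, 0.0, 2.0, 1.0));
        System.out.println(density(0.0, 1.0, 0.0, 0.0, 2.0, 1.0));
        // The circular variant gives the same value in both directions.
        System.out.println(circularDensity(1.0, 0.0, 0.0, 0.0, 1.0));
        System.out.println(circularDensity(0.0, 1.0, 0.0, 0.0, 1.0));
    }
}
```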

The following parameters are unknown: wj, mxj, myj, sxj, syj. To cluster the dataset into K fuzzy clusters, we must find the values of these parameters for each cluster that minimize the following function (the negative log-likelihood of the data):

L(d)= -ln(p(x1,y1)) - ... -ln(p(xn,yn))
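
As a sketch of how L(d) can be evaluated for given parameters, the snippet below follows the formulas above directly. The class `NegLogLikelihood` and its method names are assumptions made for this example; the array layout (w[j], m[j][0..1], s[j][0..1]) mirrors the one used in the program further down.

```java
final class NegLogLikelihood {
    // Mixture density p(x,y) = w_1 p_1(x,y) + ... + w_K p_K(x,y).
    static double mixtureDensity(double x, double y,
                                 double[] w, double[][] m, double[][] s) {
        double p = 0.0;
        for (int j = 0; j < w.length; j++) {
            double zx = (x - m[j][0]) / s[j][0];
            double zy = (y - m[j][1]) / s[j][1];
            p += w[j] * Math.exp(-0.5 * (zx * zx + zy * zy))
                      / (2 * Math.PI * s[j][0] * s[j][1]);
        }
        return p;
    }

    // L(d) = -ln p(x_1,y_1) - ... - ln p(x_n,y_n).
    static double negLogLikelihood(double[][] d,
                                   double[] w, double[][] m, double[][] s) {
        double L = 0.0;
        for (double[] xy : d) {
            L -= Math.log(mixtureDensity(xy[0], xy[1], w, m, s));
        }
        return L;
    }

    public static void main(String[] args) {
        // Example values only: one component centred at the origin.
        double[][] d = {{0.0, 0.0}, {1.0, 1.0}};
        double[] w = {1.0};
        double[][] m = {{0.0, 0.0}};
        double[][] s = {{1.0, 1.0}};
        System.out.println(negLogLikelihood(d, w, m, s));
    }
}
```

A smaller L(d) means the mixture explains the data better, which is why the parameters are chosen to minimize it.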

This is where the Expectation-Maximization algorithm is used. Once the unknown parameters are found, it is easy to compute the probability that an arbitrary point (xi,yi) belongs to cluster j:

wj pj (xi, yi) / p (xi, yi)
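
The membership formula can be sketched as follows. The class `Responsibilities` and the method `memberships` are hypothetical names for this example; the parameter arrays use the same layout as in the program below.

```java
final class Responsibilities {
    // Returns, for one point (x,y), the probability of membership in each
    // cluster j: w_j p_j(x,y) / p(x,y).
    static double[] memberships(double x, double y,
                                double[] w, double[][] m, double[][] s) {
        double[] r = new double[w.length];
        double p = 0.0; // mixture density p(x,y), accumulated below
        for (int j = 0; j < w.length; j++) {
            double zx = (x - m[j][0]) / s[j][0];
            double zy = (y - m[j][1]) / s[j][1];
            r[j] = w[j] * Math.exp(-0.5 * (zx * zx + zy * zy))
                        / (2 * Math.PI * s[j][0] * s[j][1]);
            p += r[j];
        }
        for (int j = 0; j < r.length; j++) {
            r[j] /= p;
        }
        return r;
    }

    public static void main(String[] args) {
        // Example values only: two equally weighted clusters at (-1,0) and (1,0).
        double[] w = {0.5, 0.5};
        double[][] m = {{-1.0, 0.0}, {1.0, 0.0}};
        double[][] s = {{1.0, 1.0}, {1.0, 1.0}};
        System.out.println(java.util.Arrays.toString(memberships(1.0, 0.0, w, m, s)));
    }
}
```

By construction the memberships of a point sum to 1 over the K clusters.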


Below is the code. The input and output formats are described in the comment at the top. Input data is read from standard input.

import java.io.*;

// w - coefficients of the mixture (weights of the Gaussian distributions)
// m - means, stored as pairs: index 0 for x, index 1 for y
// s - standard deviations, stored as pairs: index 0 for x, index 1 for y
//
// d=((x_1,y_1),..., (x_n,y_n))
// p(x,y)= w_1 p_1(x,y)+ ... + w_K p_K(x,y)
// w_1+...+w_K=1
// p_j(x,y)= 1/(2 pi s_xj s_yj) exp{-[(x-m_xj)/s_xj]^2/2 -[(y-m_yj)/s_yj]^2/2}
//
// unknown parameters: w_j, m_xj, m_yj, s_xj, s_yj
//
// minimized function: L(d)= -ln(p(x_1,y_1)) - ... -ln(p(x_n,y_n))
//
// Input format:
// K n
// x_1 y_1
// x_2 y_2
// ...
// x_n y_n
//
// Output format:
// L(d)
// w_1 mx_1 sx_1 my_1 sy_1
// w_2 mx_2 sx_2 my_2 sy_2
// ...
// w_K mx_K sx_K my_K sy_K
//

public class EM {
    // Pairs (x,y)
    double[][] points;
    int K, N;
    double L;
    double w[], w_prev[];
    double m[][], m_prev[][];
    double s[][], s_prev[][];
    double P_x[], P_x_prev[];
    double Min[], Max[];

    EM(double[][] points, int K) throws Exception {
        init(points);
        setM(K);
    }

    // Stores the data set and computes the bounding box of the points.
    private void init(double[][] _points) throws Exception {
        points = _points;
        N = points.length;
        // Pairs of coordinates required
        if (N == 0 || points[0].length != 2) {
            throw new IllegalArgumentException(
                "input must be a non-empty list of (x,y) pairs");
        }
        Min = new double[2];
        Max = new double[2];
        for (int i = 0; i < 2; i++) {
            Min[i] = points[0][i];
            Max[i] = points[0][i];
            for (int n = 1; n < N; n++) {
                if (Min[i] > points[n][i]) Min[i] = points[n][i];
                else if (Max[i] < points[n][i]) Max[i] = points[n][i];
            }
        }
    }

    // Allocates the parameter arrays and picks a random starting point.
    private void setM(int _K) {
        double sd[] = {0.0, 0.0};

        K = _K;
        P_x = new double[N];
        P_x_prev = new double[N];
        w = new double[K];
        m = new double[K][2];
        s = new double[K][2];
        w_prev = new double[K];
        m_prev = new double[K][2];
        s_prev = new double[K][2];

        // Standard deviation of a uniform distribution over [Min, Max].
        // The original listing used the variance (range^2 / 12) here, although
        // s holds standard deviations; the square root keeps the initial guess
        // on the same scale as the data.
        sd[0] = (Max[0] - Min[0]) / Math.sqrt(12);
        sd[1] = (Max[1] - Min[1]) / Math.sqrt(12);

        for (int j = 0; j < K; j++) {
            w[j] = 1.0 / (double) K;
            s[j][0] = sd[0] / 4 + Math.random() * sd[0] / 4;
            s[j][1] = sd[1] / 4 + Math.random() * sd[1] / 4;

            // Random means within the central half of the bounding box
            m[j][0] = (Max[0] + Min[0]) / 2 + (Math.random() - 0.5) *
                      (Max[0] - Min[0]) / 2;
            m[j][1] = (Max[1] + Min[1]) / 2 + (Math.random() - 0.5) *
                      (Max[1] - Min[1]) / 2;
        }
    }

    // Copies the current parameters into the *_prev arrays.
    private void calcPrev() {
        for (int j = 0; j < K; j++) {
            w_prev[j] = w[j];
            for (int i = 0; i < 2; i++) {
                s_prev[j][i] = s[j][i];
                m_prev[j][i] = m[j][i];
            }
        }
    }

    // p_j(x_n, y_n): density of component j at point n (current parameters).
    // Each dimension is scaled by its own standard deviation, as in the
    // formula at the top of the file.
    private double P_x_j(int n, int j) {
        double e = 0.0;
        for (int i = 0; i < 2; i++) {
            double z = (points[n][i] - m[j][i]) / s[j][i];
            e += z * z;
        }
        e /= -2;
        return Math.exp(e) / (2 * Math.PI * s[j][0] * s[j][1]);
    }

    // Probability that point n belongs to component j (current parameters).
    private double P_j_x(int n, int j) {
        return P_x_j(n, j) * w[j] / P_x[n];
    }

    // Mixture density of every point (current parameters).
    private void P_x() {
        for (int n = 0; n < N; n++) {
            double temp = 0.0;
            for (int j = 0; j < K; j++) {
                temp += P_x_j(n, j) * w[j];
            }
            P_x[n] = temp;
        }
    }

    // Same as P_x_j, but with the parameters of the previous iteration.
    private double P_x_j_prev(int n, int j) {
        double e = 0.0;
        for (int i = 0; i < 2; i++) {
            double z = (points[n][i] - m_prev[j][i]) / s_prev[j][i];
            e += z * z;
        }
        e /= -2;
        return Math.exp(e) / (2 * Math.PI * s_prev[j][0] * s_prev[j][1]);
    }

    // Calculates p(x,y) for a given point
    // p(x,y) = w1*p1(x,y) + ... + wK*pK(x,y)
    private double pxy(double[] xy) {
        double t = 0.0;
        double e = 0.0;
        for (int j = 0; j < K; j++) {
            e = 0.0;
            // exp{ -[(x-mxj)/sxj]^2/2 -[(y-myj)/syj]^2/2 }
            /*x*/ e += (xy[0] - m[j][0]) * (xy[0] - m[j][0]) / s[j][0] / s[j][0];
            /*y*/ e += (xy[1] - m[j][1]) * (xy[1] - m[j][1]) / s[j][1] / s[j][1];

            e /= -2;

            // 1/(2 pi sxj syj) * ...
            e = Math.exp(e) / (2 * Math.PI * s[j][0] * s[j][1]);
            t += e * w[j];
        }
        return t;
    }

    // Negative log-likelihood L(d) for the current parameters.
    private double calcL() {
        double subres = 0.0;
        for (int i = 0; i < N; i++) {
            subres += -1.0 * Math.log(pxy(points[i]));
        }
        L = subres;
        return L;
    }

    // Probability that point n belongs to component j (previous parameters).
    private double P_j_x_prev(int n, int j) {
        return P_x_j_prev(n, j) * w_prev[j] / P_x_prev[n];
    }

    // Mixture density of every point (previous parameters).
    private void P_x_prev() {
        for (int n = 0; n < N; n++) {
            double temp = 0.0;
            for (int j = 0; j < K; j++)
                temp += P_x_j_prev(n, j) * w_prev[j];
            P_x_prev[n] = temp;
        }
    }

    // One EM iteration: memberships are computed from the previous
    // parameters (E step), then w, m and s are re-estimated (M step).
    private void doIteration() {
        calcPrev();
        P_x();
        P_x_prev();

        for (int j = 0; j < K; j++) {
            // Sum of the memberships of all points in component j
            double denom = 0.0;
            for (int n = 0; n < N; n++) {
                denom += P_j_x_prev(n, j);
            }
            // New means: membership-weighted average of the points
            for (int i = 0; i < 2; i++) {
                m[j][i] = 0.0;
                for (int n = 0; n < N; n++)
                    m[j][i] += P_j_x_prev(n, j) * points[n][i];
                m[j][i] /= denom;
            }

            // New standard deviations: membership-weighted spread around
            // the new means
            s[j][0] = 0.0;
            s[j][1] = 0.0;
            for (int n = 0; n < N; n++) {
                double sum;
                // x
                sum = (points[n][0] - m[j][0]) * (points[n][0] - m[j][0]);
                s[j][0] += sum * P_j_x_prev(n, j);
                // y
                sum = (points[n][1] - m[j][1]) * (points[n][1] - m[j][1]);
                s[j][1] += sum * P_j_x_prev(n, j);
            }
            s[j][0] /= denom;
            s[j][1] /= denom;
            s[j][0] = Math.sqrt(s[j][0]);
            s[j][1] = Math.sqrt(s[j][1]);

            // New weight: average membership in component j
            w[j] = denom / (double) N;
        }
    }

    // Sum of absolute parameter changes made by the last iteration.
    private double howMuchChanged() {
        double result = 0.0;
        for (int j = 0; j < K; j++) {
            result += Math.abs(w[j] - w_prev[j]) +
                      Math.abs(m[j][0] - m_prev[j][0]) +
                      Math.abs(m[j][1] - m_prev[j][1]) +
                      Math.abs(s[j][0] - s_prev[j][0]) +
                      Math.abs(s[j][1] - s_prev[j][1]);
        }
        return result;
    }

    void printResult() {
        System.out.println(L);
        for (int j = 0; j < K; j++) {
            System.out.println(w[j] + " " + m[j][0] + " " + s[j][0] +
                               " " + m[j][1] + " " + s[j][1]);
        }
    }

    // Parses a line of two whitespace-separated integers.
    public static int[] getIntPair(String line) {
        String[] parts = line.trim().split("\\s+");
        int[] KN = {Integer.parseInt(parts[0]), Integer.parseInt(parts[1])};
        return KN;
    }

    // Parses a line of two whitespace-separated doubles.
    public static double[] getDoublePair(String line) {
        String[] parts = line.trim().split("\\s+");
        double[] XY = {Double.parseDouble(parts[0]), Double.parseDouble(parts[1])};
        return XY;
    }

    public static void main(String args[]) throws Exception {
        double[][] inData;

        BufferedReader input = new
            BufferedReader(new InputStreamReader(System.in));
        String line = input.readLine();
        if (line == null)
            return;

        int[] KN = getIntPair(line);

        // Read (x,y) pairs
        inData = new double[KN[1]][2];
        for (int i = 0; i < KN[1]; i++) {
            line = input.readLine();
            if (line == null)
                throw new IOException("expected " + KN[1] + " points, got " + i);
            double[] XY = getDoublePair(line);
            inData[i][0] = XY[0];
            inData[i][1] = XY[1];
        }

        EM em = new EM(inData, KN[0]);
        int its = 100000000;
        while (its-- > 0) {
            em.doIteration();
            // Stop on convergence
            if (em.howMuchChanged() < 0.00000001)
                break;
        }
        em.calcL();
        em.printResult();
    }
}
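
To try the program, some input in the format described above is needed. The following is a small sketch of a generator for synthetic data. The class `SampleInput`, the cluster centres, the standard deviations and the seed are all arbitrary choices made for this example, not values from the original post.

```java
import java.util.Random;

final class SampleInput {
    // Builds input text in the "K n" + "x y" format, drawing perCluster
    // points from each of two axis-aligned Gaussian clusters.
    static String generate(int perCluster, long seed) {
        Random rnd = new Random(seed);
        double[][] centers = {{0.0, 0.0}, {10.0, 5.0}}; // assumed example centres
        double[][] sds = {{1.0, 2.0}, {2.0, 0.5}};      // assumed example deviations
        StringBuilder sb = new StringBuilder();
        // Header line: number of clusters K and number of points n
        sb.append(centers.length).append(' ')
          .append(centers.length * perCluster).append('\n');
        for (int j = 0; j < centers.length; j++) {
            for (int n = 0; n < perCluster; n++) {
                double x = centers[j][0] + sds[j][0] * rnd.nextGaussian();
                double y = centers[j][1] + sds[j][1] * rnd.nextGaussian();
                sb.append(x).append(' ').append(y).append('\n');
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.print(generate(100, 42L));
    }
}
```

The output can be saved to a file and passed to the clustering program on standard input. Since the starting point of EM is random, the order of the reported clusters may differ between runs.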

1 comment:

mkarabulut said...

Hello,

Very nice code...

I guess this code is one of the few freely published implementations of the EM algorithm for GMMs.

I've utilized this code to write my own EM implementation in C#.

I've got a question about GMMs. This code, with the given PDF, works well for the two-dimensional case. But what if we use it for higher-dimensional vectors, for instance vectors with 8 or 9 components (dimensions)?

I've tried but it failed.

Have you got an idea about this issue?

Thank you.
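
A hedged note on the question above: the comment does not say how the higher-dimensional attempt failed, so the following is only one possible direction, not a diagnosis. With one standard deviation per dimension, the density for D dimensions is a product of D one-dimensional Gaussians, so the normalizing constant becomes (2 π)^(D/2) s1 ... sD rather than the 2 π sx sy used for two dimensions. With 8 or 9 dimensions that product can become very small, and computing with logarithms is a common precaution. The class `DiagonalGaussianLog` and its methods are hypothetical names for this sketch.

```java
final class DiagonalGaussianLog {
    // Natural log of the D-dimensional density with one standard deviation
    // per dimension: sum over i of -z_i^2/2 - ln(s_i) - ln(2 pi)/2.
    static double logDensity(double[] x, double[] mean, double[] sd) {
        double log = 0.0;
        for (int i = 0; i < x.length; i++) {
            double z = (x[i] - mean[i]) / sd[i];
            log += -0.5 * z * z - Math.log(sd[i]) - 0.5 * Math.log(2 * Math.PI);
        }
        return log;
    }

    // ln(exp(a_1) + ... + exp(a_K)), computed without leaving log space,
    // e.g. for a_j = ln(w_j) + logDensity of component j.
    static double logSumExp(double[] a) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : a) max = Math.max(max, v);
        if (max == Double.NEGATIVE_INFINITY) return max;
        double sum = 0.0;
        for (double v : a) sum += Math.exp(v - max);
        return max + Math.log(sum);
    }

    public static void main(String[] args) {
        // Example values only: a 9-dimensional point far from the mean.
        double[] x = new double[9];
        double[] mean = new double[9];
        double[] sd = new double[9];
        java.util.Arrays.fill(x, 20.0);
        java.util.Arrays.fill(sd, 1.0);
        double lp = logDensity(x, mean, sd);
        System.out.println("log density: " + lp);
        System.out.println("plain density: " + Math.exp(lp));
    }
}
```

In this sketch the membership of a point in cluster j would be exp(a_j - logSumExp(a)), which stays well defined even when every plain density underflows to zero.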