gaussianMixture
Advanced tools
Comparing version 0.8.0 to 0.9.0
78
index.js
@@ -338,3 +338,2 @@ 'use strict'; | ||
* The initialization is agnostic to the other priors that the options might contain. | ||
* The `initialize` flag is unavailable with the histogram version of this function | ||
* @param {(Array|Histogram)} data the data array or histogram | ||
@@ -370,3 +369,3 @@ * @param {Number} [maxIterations=200] maximum number of expectation-maximization steps | ||
GMM.prototype._optimize = function (data, maxIterations, logLikelihoodTol) { | ||
if (this.options.initialize) this.initialize(data); | ||
if (this.options.initialize) this._initialize(data); | ||
@@ -391,3 +390,2 @@ maxIterations = maxIterations === undefined ? MAX_ITERATIONS : maxIterations; | ||
* Compute the optimal GMM components given a histogram of data. | ||
* K-means++ initialization is not implemented for the histogram version of this function. | ||
* @param {Histogram} h histogram of data used to optimize the model | ||
@@ -405,2 +403,4 @@ * @param {Number} [maxIterations=200] maximum number of expectation-maximization steps | ||
GMM.prototype._optimizeHistogram = function (h, maxIterations, logLikelihoodTol) { | ||
if (this.options.initialize) this._initializeHistogram(h); | ||
maxIterations = maxIterations === undefined ? MAX_ITERATIONS : maxIterations; | ||
@@ -421,3 +421,3 @@ logLikelihoodTol = logLikelihoodTol === undefined ? EPSILON : logLikelihoodTol; | ||
/** | ||
/** @private | ||
* Initialize the GMM given data with the [K-means++](https://en.wikipedia.org/wiki/K-means%2B%2B) initialization algorithm. | ||
@@ -433,3 +433,3 @@ * The k-means++ algorithm choses datapoints amongst the data at random, while ensuring that the chosen seeds are far from each other. | ||
*/ | ||
GMM.prototype.initialize = function (data) { | ||
GMM.prototype._initialize = function (data) { | ||
var n = data.length; | ||
@@ -479,2 +479,70 @@ | ||
/** @private | ||
* Initialize the GMM given data with the [K-means++](https://en.wikipedia.org/wiki/K-means%2B%2B) initialization algorithm. | ||
* The k-means++ algorithm choses datapoints amongst the data at random, while ensuring that the chosen seeds are far from each other. | ||
* The resulting seeds are returned sorted. | ||
* @param {Array} data array of numbers representing the samples to use to optimize the model | ||
* @return {Array} an array of length nComponents that contains the means for the initialization. | ||
* @example | ||
var gmm = new GMM(3, [0.3, .04, 0.3], [1, 5, 10]); | ||
var data = [1.2, 1.3, 7.4, 1.4, 14.3, 15.3, 1.0, 7.2]; | ||
gmm.initialize(data); // updates the means of the GMM with the K-means++ initialization algorithm, returns something like [1.3, 7.4, 14.3] | ||
*/ | ||
GMM.prototype._initializeHistogram = function (h) { | ||
var n = h.total; | ||
if (n < this.nComponents) throw new Error('Data must have more points than the number of components in the model.'); | ||
var keys = Object.keys(h.counts); | ||
var means = []; | ||
// Find the first seed at random | ||
var r = Math.random(); | ||
for (let i = 0; i < keys.length; i++) { | ||
let k = keys[i]; | ||
let p = (h.counts[k] / n) || 0; | ||
if (p > r || i === (keys.length - 1)) { | ||
means.push(h.value(k)); | ||
break; | ||
} else { | ||
r -= p; | ||
} | ||
} | ||
var distances = []; | ||
// Chose all other seeds | ||
for (let m = 1; m < this.nComponents; m++) { | ||
// Compute the distance from each datapoint | ||
var dsum = 0; | ||
for (let i = 0; i < keys.length; i++) { | ||
let k = keys[i]; | ||
var meansDistances = means.map(function (x) { return (x - h.value(k)) * (x - h.value(k)); }); | ||
var d = meansDistances.reduce(function (a, b) { return Math.min(a, b); }); | ||
distances[i] = d * h.counts[k]; | ||
dsum += d; | ||
} | ||
// Chose the next seed at random with probabilities d / dsum | ||
let r = Math.random(); | ||
for (let i = 0; i < keys.length; i++) { | ||
let k = keys[i]; | ||
let p = (distances[i] / dsum) || 0; | ||
if (p > r || i === (keys.length - 1)) { | ||
means.push(h.value(k)); | ||
break; | ||
} else { | ||
r -= p; | ||
} | ||
} | ||
} | ||
means.sort(function (a, b) { return a - b; }); | ||
this.means = means; | ||
return means; | ||
}; | ||
/** @private | ||
* Compute the barycenter given an array and weights. | ||
@@ -481,0 +549,0 @@ * @param {Array} array the array of values to find the barycenter from |
{ | ||
"name": "gaussianMixture", | ||
"version": "0.8.0", | ||
"version": "0.9.0", | ||
"description": "An implementation of a Gaussian Mixture class in one dimension, that allows to fit models with an Expectation Maximization algorithm.", | ||
@@ -5,0 +5,0 @@ "main": "index.js", |
@@ -196,9 +196,9 @@ 'use strict'; | ||
var means = gmm.initialize([1, 3, 3, 3, 2, 2, 1, 1, 3, 2, 2, 1, 3, 3, 3, 2, 1]); | ||
var means = gmm._initialize([1, 3, 3, 3, 2, 2, 1, 1, 3, 2, 2, 1, 3, 3, 3, 2, 1]); | ||
t.same(means, [1, 2, 3]); | ||
t.same(gmm.initialize([1, 1, 1, 1]), [1, 1, 1]); | ||
t.same(gmm.initialize([1, 1, 1, 2, 17]), [1, 2, 17]); | ||
t.same(gmm._initialize([1, 1, 1, 1]), [1, 1, 1]); | ||
t.same(gmm._initialize([1, 1, 1, 2, 17]), [1, 2, 17]); | ||
t.throws(function () { gmm.initialize([1]); }, new Error('Data must have more points than the number of components in the model.')); | ||
t.throws(function () { gmm._initialize([1]); }, new Error('Data must have more points than the number of components in the model.')); | ||
@@ -208,2 +208,23 @@ t.end(); | ||
test('Km++ Initialization - Histogram', function (t) { | ||
var gmm = new GMM(3, [0.4, 0.2, 0.4], [-1, 13, 25], [1, 2, 1]); | ||
var h = Histogram.fromData([1, 3, 3, 3, 2, 2, 1, 1, 3, 2, 2, 1, 3, 3, 3, 2, 1]); | ||
var means = gmm._initializeHistogram(h); | ||
t.same(means, [1, 2, 3]); | ||
t.same(gmm._initializeHistogram(Histogram.fromData([1, 1, 1, 1])), [1, 1, 1]); | ||
t.same(gmm._initializeHistogram(Histogram.fromData([1, 1, 1, 2, 17])), [1, 2, 17]); | ||
t.throws(function () { gmm._initializeHistogram(Histogram.fromData([1])); }, new Error('Data must have more points than the number of components in the model.')); | ||
h = new Histogram({ | ||
counts: {'A': 10000, 'B': 0.001, 'C': 10, 'D': 10}, | ||
bins: {'A': [0, 1], 'B': [1, 2], 'C': [3, 4], 'D': [4, 5]} | ||
}); | ||
t.same(gmm._initializeHistogram(h), [0.5, 3.5, 4.5]); | ||
t.end(); | ||
}); | ||
test('memberships - histogram', function (t) { | ||
@@ -267,3 +288,4 @@ var h = Histogram.fromData([1, 2, 5, 5.4, 5.5, 6, 7, 7]); | ||
separationPrior: 3, | ||
separationPriorRelevance: 1 | ||
separationPriorRelevance: 1, | ||
initialize: true | ||
}; | ||
@@ -270,0 +292,0 @@ |
53771
907