Comparing version
exports.NeuralNetwork = require("./neuralnetwork").NeuralNetwork; | ||
exports.crossValidate = require("./cross-validate").crossValidate; | ||
exports.crossValidate = require("./cross-validate"); |
var _ = require("underscore")._; | ||
function testPartition(classifierConst, options, trainSet, testSet) { | ||
var classifier = new classifierConst(options); | ||
function testPartition(classifierConst, opts, trainOpts, trainSet, testSet) { | ||
var classifier = new classifierConst(opts); | ||
var beginTrain = Date.now(); | ||
var trainingStats = classifier.train(trainSet); | ||
var trainingStats = classifier.train(trainSet, trainOpts); | ||
@@ -16,17 +16,18 @@ var beginTest = Date.now(); | ||
return { | ||
error : testStats.error, | ||
misclasses: testStats.misclasses, | ||
var stats = _(testStats).extend({ | ||
trainTime : beginTest - beginTrain, | ||
testTime : endTest - beginTest, | ||
iterations: trainingStats.iterations, | ||
trainError: trainingStats.error | ||
}; | ||
trainError: trainingStats.error, | ||
learningRate: classifier.learningRate, | ||
hidden: classifier.hiddenSizes | ||
}); | ||
return stats; | ||
} | ||
module.exports = function crossValidate(classifierConst, options, data, k) { | ||
module.exports = function crossValidate(classifierConst, data, opts, trainOpts, k) { | ||
k = k || 4; | ||
var size = data.length / k; | ||
data = _(data).sortBy(function(num){ | ||
data = _(data).sortBy(function() { | ||
return Math.random(); | ||
@@ -40,3 +41,6 @@ }); | ||
iterations: 0, | ||
trainError: 0 | ||
trainError: 0, | ||
precision: 0, | ||
accuracy: 0, | ||
recall: 0 | ||
}; | ||
@@ -51,3 +55,3 @@ | ||
var result = testPartition(classifierConst, options, trainSet, testSet); | ||
var result = testPartition(classifierConst, opts, trainOpts, trainSet, testSet); | ||
@@ -72,2 +76,2 @@ _(avgs).each(function(sum, i) { | ||
}; | ||
} | ||
} |
@@ -76,7 +76,13 @@ var _ = require("underscore"), | ||
train: function(data, errorThresh, iterations) { | ||
train: function(data, options) { | ||
data = this.formatData(data); | ||
iterations = iterations || 20000; | ||
errorThresh = errorThresh || 0.004; | ||
options = options || {}; | ||
var iterations = options.iterations || 20000; | ||
var errorThresh = options.errorThresh || 0.005; | ||
var log = options.log || false; | ||
var logPeriod = options.logPeriod || 10; | ||
var callback = options.callback; | ||
var callbackPeriod = options.callbackPeriod || 10; | ||
var inputSize = data[0].input.length; | ||
@@ -87,3 +93,3 @@ var outputSize = data[0].output.length; | ||
if (!hiddenSizes) { | ||
hiddenSizes = [Math.max(3, inputSize)]; | ||
hiddenSizes = [Math.max(3, Math.floor(inputSize / 2))]; | ||
} | ||
@@ -97,5 +103,13 @@ var sizes = _([inputSize, hiddenSizes, outputSize]).flatten(); | ||
for (var j = 0; j < data.length; j++) { | ||
sum += this.trainPattern(data[j].input, data[j].output); | ||
var err = this.trainPattern(data[j].input, data[j].output); | ||
sum += err; | ||
} | ||
error = sum / data.length; | ||
if (log && (i % logPeriod == 0)) { | ||
console.log("iterations:", i, "training error:", error); | ||
} | ||
if (callback && (i % callbackPeriod == 0)) { | ||
callback({ error: error, iterations: i }); | ||
} | ||
} | ||
@@ -171,3 +185,3 @@ | ||
var array = lookup.toArray(this.inputLookup, datum.input) | ||
return {input: array, output: datum.output}; | ||
return _(_(datum).clone()).extend({ input: array }); | ||
}, this); | ||
@@ -182,3 +196,3 @@ } | ||
var array = lookup.toArray(this.outputLookup, datum.output); | ||
return {input: datum.input, output: array}; | ||
return _(_(datum).clone()).extend({ output: array }); | ||
}, this); | ||
@@ -189,26 +203,59 @@ } | ||
test : function(data) { | ||
test : function(data, binaryThresh) { | ||
data = this.formatData(data); | ||
binaryThresh = binaryThresh || 0.5; | ||
// for binary classification problems with one output node | ||
var isBinary = data[0].output.length == 1; | ||
var falsePos = 0, | ||
falseNeg = 0, | ||
truePos = 0, | ||
trueNeg = 0; | ||
// for classification problems | ||
var misclasses = []; | ||
// run each pattern through the trained network and collect | ||
// error and misclassification statistics | ||
var sum = 0; | ||
var misclasses = []; | ||
for (var i = 0; i < data.length; i++) { | ||
var output = this.runInput(data[i].input); | ||
var expected = data[i].output; | ||
var target = data[i].output; | ||
var actualClass = output.indexOf(_(output).max()); | ||
var expectedClass = expected.indexOf(_(expected).max()); | ||
var actual, expected; | ||
if (isBinary) { | ||
actual = output[0] > binaryThresh ? 1 : 0; | ||
expected = target[0]; | ||
} | ||
else { | ||
actual = output.indexOf(_(output).max()); | ||
expected = target.indexOf(_(target).max()); | ||
} | ||
if (actualClass != expectedClass) { | ||
misclasses.push({ | ||
input: data[i].input, | ||
actual: actualClass, | ||
expected: expectedClass | ||
}); | ||
if (actual != expected) { | ||
var misclass = data[i]; | ||
_(misclass).extend({ | ||
actual: actual, | ||
expected: expected | ||
}) | ||
misclasses.push(misclass); | ||
} | ||
if (isBinary) { | ||
if (actual == 0 && expected == 0) { | ||
trueNeg++; | ||
} | ||
else if (actual == 1 && expected == 1) { | ||
truePos++; | ||
} | ||
else if (actual == 0 && expected == 1) { | ||
falseNeg++; | ||
} | ||
else if (actual == 1 && expected == 0) { | ||
falsePos++; | ||
} | ||
} | ||
var errors = output.map(function(value, i) { | ||
return expected[i] - value; | ||
return target[i] - value; | ||
}); | ||
@@ -219,6 +266,20 @@ sum += mse(errors); | ||
return { | ||
var stats = { | ||
error: error, | ||
misclasses: misclasses | ||
}; | ||
if (isBinary) { | ||
_(stats).extend({ | ||
trueNeg: trueNeg, | ||
truePos: truePos, | ||
falseNeg: falseNeg, | ||
falsePos: falsePos, | ||
total: data.length, | ||
precision: truePos / (truePos + falsePos), | ||
recall: truePos / (truePos + falseNeg), | ||
accuracy: (trueNeg + truePos) / data.length | ||
}) | ||
} | ||
return stats; | ||
}, | ||
@@ -225,0 +286,0 @@ |
{ | ||
"name": "brain", | ||
"description": "Neural network library", | ||
"version": "0.5.0", | ||
"version": "0.6.0", | ||
"author": "Heather Arthur <fayearthur@gmail.com>", | ||
@@ -6,0 +6,0 @@ "repository": { |
@@ -42,7 +42,18 @@ # brain | ||
#### Threshold | ||
The optional second argument to `train()` is the error threshold (default `0.004`), the third is the maximum training iterations (default `20000`). | ||
#### Options | ||
`train()` takes a hash of options as its second argument: | ||
The network will train until the training error has gone below the threshold or the max number of iterations has been reached, whichever comes first. | ||
```javascript | ||
net.train(data, { | ||
errorThresh: 0.004, // error threshold to reach | ||
iterations: 20000, // maximum training iterations | ||
log: true, // console.log() progress periodically | ||
logPeriod: 10 // number of iterations between logging | ||
}) | ||
``` | ||
The network will train until the training error has gone below the threshold (default `0.004`) or the max number of iterations (default `20000`) has been reached, whichever comes first. | ||
By default training won't let you know how its doing until the end, but set `log` to `true` to get periodic updates on the current training error of the network. The training error should decrease every time. | ||
#### Output | ||
@@ -102,3 +113,3 @@ The ouput of `train()` is a hash of information about how the training went: | ||
#### learningRate | ||
The learning rate is a parameter that influences how quickly the network trains. It's a number from `0` to `1`. If the learning rate is close to `0` it will take a lot longer to train. If the learning rate is closer to `1` it will train faster but it's in danger of training to a local minimum and performing badly on new data. The default learning rate is `0.3`. | ||
The learning rate is a parameter that influences how quickly the network trains. It's a number from `0` to `1`. If the learning rate is close to `0` it will take longer to train. If the learning rate is closer to `1` it will train faster but it's in danger of training to a local minimum and performing badly on new data. The default learning rate is `0.3`. | ||
@@ -105,0 +116,0 @@ |
@@ -8,3 +8,3 @@ var assert = require('should'), | ||
var net = new brain.NeuralNetwork(); | ||
net.train(data); | ||
net.train(data, { errorThresh: 0.003 }); | ||
@@ -11,0 +11,0 @@ for(var i in data) { |
@@ -6,3 +6,3 @@ var assert = require('should'), | ||
describe('neural network options', function() { | ||
it('hidden', function() { | ||
it('hiddenLayers', function() { | ||
var net = new brain.NeuralNetwork({ hiddenLayers: [8, 7] }); | ||
@@ -22,2 +22,17 @@ | ||
it('hiddenLayers default expand to input size', function() { | ||
var net = new brain.NeuralNetwork(); | ||
net.train([{input: [0, 0, 1, 1, 1, 1, 1, 1, 1], output: [0]}, | ||
{input: [0, 1, 1, 1, 1, 1, 1, 1, 1], output: [1]}, | ||
{input: [1, 0, 1, 1, 1, 1, 1, 1, 1], output: [1]}, | ||
{input: [1, 1, 1, 1, 1, 1, 1, 1, 1], output: [0]}]); | ||
var json = net.toJSON(); | ||
assert.equal(json.layers.length, 3); | ||
assert.equal(_(json.layers[1]).keys().length, 4, "9 input units means 4 hidden"); | ||
}) | ||
it('learningRate - higher learning rate should train faster', function() { | ||
@@ -24,0 +39,0 @@ var data = [{input: [0, 0], output: [0]}, |
Major refactor
Supply chain riskPackage has recently undergone a major refactor. It may be unstable or indicate significant internal changes. Use caution when updating to versions that include significant changes.
Found 1 instance in 1 package
Uses eval
Supply chain riskPackage uses dynamic code execution (e.g., eval()), which is a dangerous practice. This can prevent the code from running in certain environments and increases the risk that the code may contain exploits or malicious behavior.
Found 1 instance in 1 package
Dynamic require
Supply chain riskDynamic require can indicate the package is performing dangerous or unsafe dynamic code execution.
Found 1 instance in 1 package
Major refactor
Supply chain riskPackage has recently undergone a major refactor. It may be unstable or indicate significant internal changes. Use caution when updating to versions that include significant changes.
Found 1 instance in 1 package
182720
337.22%27
42.11%3527
238.16%118
10.28%11
266.67%6
200%