🚀 Big News: Socket Acquires Coana to Bring Reachability Analysis to Every Appsec Team.Learn more →

brain

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

brain - npm Package Compare versions

Comparing version

to
0.6.0

exports.NeuralNetwork = require("./neuralnetwork").NeuralNetwork;
exports.crossValidate = require("./cross-validate").crossValidate;
exports.crossValidate = require("./cross-validate");
var _ = require("underscore")._;
function testPartition(classifierConst, options, trainSet, testSet) {
var classifier = new classifierConst(options);
function testPartition(classifierConst, opts, trainOpts, trainSet, testSet) {
var classifier = new classifierConst(opts);
var beginTrain = Date.now();
var trainingStats = classifier.train(trainSet);
var trainingStats = classifier.train(trainSet, trainOpts);

@@ -16,17 +16,18 @@ var beginTest = Date.now();

return {
error : testStats.error,
misclasses: testStats.misclasses,
var stats = _(testStats).extend({
trainTime : beginTest - beginTrain,
testTime : endTest - beginTest,
iterations: trainingStats.iterations,
trainError: trainingStats.error
};
trainError: trainingStats.error,
learningRate: classifier.learningRate,
hidden: classifier.hiddenSizes
});
return stats;
}
module.exports = function crossValidate(classifierConst, options, data, k) {
module.exports = function crossValidate(classifierConst, data, opts, trainOpts, k) {
k = k || 4;
var size = data.length / k;
data = _(data).sortBy(function(num){
data = _(data).sortBy(function() {
return Math.random();

@@ -40,3 +41,6 @@ });

iterations: 0,
trainError: 0
trainError: 0,
precision: 0,
accuracy: 0,
recall: 0
};

@@ -51,3 +55,3 @@

var result = testPartition(classifierConst, options, trainSet, testSet);
var result = testPartition(classifierConst, opts, trainOpts, trainSet, testSet);

@@ -72,2 +76,2 @@ _(avgs).each(function(sum, i) {

};
}
}

@@ -76,7 +76,13 @@ var _ = require("underscore"),

train: function(data, errorThresh, iterations) {
train: function(data, options) {
data = this.formatData(data);
iterations = iterations || 20000;
errorThresh = errorThresh || 0.004;
options = options || {};
var iterations = options.iterations || 20000;
var errorThresh = options.errorThresh || 0.005;
var log = options.log || false;
var logPeriod = options.logPeriod || 10;
var callback = options.callback;
var callbackPeriod = options.callbackPeriod || 10;
var inputSize = data[0].input.length;

@@ -87,3 +93,3 @@ var outputSize = data[0].output.length;

if (!hiddenSizes) {
hiddenSizes = [Math.max(3, inputSize)];
hiddenSizes = [Math.max(3, Math.floor(inputSize / 2))];
}

@@ -97,5 +103,13 @@ var sizes = _([inputSize, hiddenSizes, outputSize]).flatten();

for (var j = 0; j < data.length; j++) {
sum += this.trainPattern(data[j].input, data[j].output);
var err = this.trainPattern(data[j].input, data[j].output);
sum += err;
}
error = sum / data.length;
if (log && (i % logPeriod == 0)) {
console.log("iterations:", i, "training error:", error);
}
if (callback && (i % callbackPeriod == 0)) {
callback({ error: error, iterations: i });
}
}

@@ -171,3 +185,3 @@

var array = lookup.toArray(this.inputLookup, datum.input)
return {input: array, output: datum.output};
return _(_(datum).clone()).extend({ input: array });
}, this);

@@ -182,3 +196,3 @@ }

var array = lookup.toArray(this.outputLookup, datum.output);
return {input: datum.input, output: array};
return _(_(datum).clone()).extend({ output: array });
}, this);

@@ -189,26 +203,59 @@ }

test : function(data) {
test : function(data, binaryThresh) {
data = this.formatData(data);
binaryThresh = binaryThresh || 0.5;
// for binary classification problems with one output node
var isBinary = data[0].output.length == 1;
var falsePos = 0,
falseNeg = 0,
truePos = 0,
trueNeg = 0;
// for classification problems
var misclasses = [];
// run each pattern through the trained network and collect
// error and misclassification statistics
var sum = 0;
var misclasses = [];
for (var i = 0; i < data.length; i++) {
var output = this.runInput(data[i].input);
var expected = data[i].output;
var target = data[i].output;
var actualClass = output.indexOf(_(output).max());
var expectedClass = expected.indexOf(_(expected).max());
var actual, expected;
if (isBinary) {
actual = output[0] > binaryThresh ? 1 : 0;
expected = target[0];
}
else {
actual = output.indexOf(_(output).max());
expected = target.indexOf(_(target).max());
}
if (actualClass != expectedClass) {
misclasses.push({
input: data[i].input,
actual: actualClass,
expected: expectedClass
});
if (actual != expected) {
var misclass = data[i];
_(misclass).extend({
actual: actual,
expected: expected
})
misclasses.push(misclass);
}
if (isBinary) {
if (actual == 0 && expected == 0) {
trueNeg++;
}
else if (actual == 1 && expected == 1) {
truePos++;
}
else if (actual == 0 && expected == 1) {
falseNeg++;
}
else if (actual == 1 && expected == 0) {
falsePos++;
}
}
var errors = output.map(function(value, i) {
return expected[i] - value;
return target[i] - value;
});

@@ -219,6 +266,20 @@ sum += mse(errors);

return {
var stats = {
error: error,
misclasses: misclasses
};
if (isBinary) {
_(stats).extend({
trueNeg: trueNeg,
truePos: truePos,
falseNeg: falseNeg,
falsePos: falsePos,
total: data.length,
precision: truePos / (truePos + falsePos),
recall: truePos / (truePos + falseNeg),
accuracy: (trueNeg + truePos) / data.length
})
}
return stats;
},

@@ -225,0 +286,0 @@

{
"name": "brain",
"description": "Neural network library",
"version": "0.5.0",
"version": "0.6.0",
"author": "Heather Arthur <fayearthur@gmail.com>",

@@ -6,0 +6,0 @@ "repository": {

@@ -42,7 +42,18 @@ # brain

#### Threshold
The optional second argument to `train()` is the error threshold (default `0.004`), the third is the maximum training iterations (default `20000`).
#### Options
`train()` takes a hash of options as its second argument:
The network will train until the training error has gone below the threshold or the max number of iterations has been reached, whichever comes first.
```javascript
net.train(data, {
errorThresh: 0.004, // error threshold to reach
iterations: 20000, // maximum training iterations
log: true, // console.log() progress periodically
logPeriod: 10 // number of iterations between logging
})
```
The network will train until the training error has gone below the threshold (default `0.004`) or the max number of iterations (default `20000`) has been reached, whichever comes first.
By default training won't let you know how its doing until the end, but set `log` to `true` to get periodic updates on the current training error of the network. The training error should decrease every time.
#### Output

@@ -102,3 +113,3 @@ The ouput of `train()` is a hash of information about how the training went:

#### learningRate
The learning rate is a parameter that influences how quickly the network trains. It's a number from `0` to `1`. If the learning rate is close to `0` it will take a lot longer to train. If the learning rate is closer to `1` it will train faster but it's in danger of training to a local minimum and performing badly on new data. The default learning rate is `0.3`.
The learning rate is a parameter that influences how quickly the network trains. It's a number from `0` to `1`. If the learning rate is close to `0` it will take longer to train. If the learning rate is closer to `1` it will train faster but it's in danger of training to a local minimum and performing badly on new data. The default learning rate is `0.3`.

@@ -105,0 +116,0 @@

@@ -8,3 +8,3 @@ var assert = require('should'),

var net = new brain.NeuralNetwork();
net.train(data);
net.train(data, { errorThresh: 0.003 });

@@ -11,0 +11,0 @@ for(var i in data) {

@@ -6,3 +6,3 @@ var assert = require('should'),

describe('neural network options', function() {
it('hidden', function() {
it('hiddenLayers', function() {
var net = new brain.NeuralNetwork({ hiddenLayers: [8, 7] });

@@ -22,2 +22,17 @@

it('hiddenLayers default expand to input size', function() {
var net = new brain.NeuralNetwork();
net.train([{input: [0, 0, 1, 1, 1, 1, 1, 1, 1], output: [0]},
{input: [0, 1, 1, 1, 1, 1, 1, 1, 1], output: [1]},
{input: [1, 0, 1, 1, 1, 1, 1, 1, 1], output: [1]},
{input: [1, 1, 1, 1, 1, 1, 1, 1, 1], output: [0]}]);
var json = net.toJSON();
assert.equal(json.layers.length, 3);
assert.equal(_(json.layers[1]).keys().length, 4, "9 input units means 4 hidden");
})
it('learningRate - higher learning rate should train faster', function() {

@@ -24,0 +39,0 @@ var data = [{input: [0, 0], output: [0]},