ml-cross-validation
Comparing versions 1.1.0 and 1.2.0
@@ -0,1 +1,11 @@
<a name="1.2.0"></a>
# [1.2.0](https://github.com/mljs/cross-validation/compare/v1.1.0...v1.2.0) (2017-11-08)
### Features
* allow passing a callback to cross-validation methods ([32501e2](https://github.com/mljs/cross-validation/commit/32501e2))
<a name="1.1.0"></a>
@@ -2,0 +12,0 @@ # [1.1.0](https://github.com/mljs/cross-validation/compare/v1.0.1...v1.1.0) (2017-07-07)
{
  "name": "ml-cross-validation",
  "version": "1.1.0",
  "version": "1.2.0",
  "description": "Cross validation utility for mljs classifiers",
@@ -33,3 +33,3 @@ "main": "src",
  "eslint-plugin-no-only-tests": "^2.0.0",
  "mocha": "^3.1.2",
  "mocha": "^3.5.3",
  "mocha-better-spec-reporter": "^3.0.2",
@@ -36,0 +36,0 @@ "should": "^11.1.1"
@@ -7,9 +7,66 @@ # cross-validation
Utility library to do cross validation with mljs classifiers
Utility library to do cross validation with supervised classifiers.
A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section
Cross-validation methods:
- [k-fold](https://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation)
- [leave-p-out](https://en.wikipedia.org/wiki/Cross-validation_(statistics)#Leave-p-out_cross-validation)
[Documentation](https://mljs.github.io/cross-validation/)
[API documentation](https://mljs.github.io/cross-validation/).
A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section, but you could also use your own. Cross-validation methods return a [ConfusionMatrix](https://github.com/mljs/confusion-matrix) that can be used to calculate metrics on your classification result.
## Installation
```bash
npm i -s ml-cross-validation
```
## Example using an ML classification library
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];
const confusionMatrix = crossValidation.leaveOneOut(KNN, dataset, labels);
const accuracy = confusionMatrix.getAccuracy();
```
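The other methods follow the same call pattern, with the classifier constructor as the first argument. As a minimal sketch (not part of the package README; the empty options object and the choice of 3 folds are placeholders), k-fold on the same data could look like this:
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');

const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];

// kFold(Classifier, features, labels, classifierOptions, k):
// the samples are split into k folds and each fold is used once as the test set.
const confusionMatrix = crossValidation.kFold(KNN, dataset, labels, {}, 3);
const accuracy = confusionMatrix.getAccuracy();
```
`leavePOut` takes `p` in the same position that `kFold` takes `k`.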
## Example using a classifier with its own specific API
If you have a library that does not comply with the ML classifier API conventions described below, you can use a callback to perform the classification.
The callback will receive the train features, the train labels and the test features, and should return the array of predicted labels.
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];
const confusionMatrix = crossValidation.leaveOneOut(dataset, labels, function(trainFeatures, trainLabels, testFeatures) {
  const knn = new KNN(trainFeatures, trainLabels);
  return knn.predict(testFeatures);
});
const accuracy = confusionMatrix.getAccuracy();
```
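In both forms the return value is a ConfusionMatrix instance. Beyond `getAccuracy()`, here is a small sketch of the accessors that the package's own tests rely on:
```js
// confusionMatrix is the value returned by leaveOneOut/leavePOut/kFold above
console.log(confusionMatrix.getLabels());   // distinct class labels, e.g. [0, 1] for the dataset above
console.log(confusionMatrix.getMatrix());   // counts: rows are actual labels, columns are predicted labels
console.log(confusionMatrix.getAccuracy()); // fraction of correctly classified samples
```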
## ML classifier API conventions
You can write your classification library so that it can be used with ml-cross-validation as described [here](#example-using-an-ml-classification-library).
For that, your classification library must implement:
- A constructor. The constructor can be passed options as a single argument.
- A `train` method. The `train` method is passed the data as a first argument and the labels as a second.
- A `predict` method. The `predict` method is passed test data and should return a predicted label.
### Example
```js
class MyClassifier {
  constructor(options) {
    this.options = options;
  }
  train(data, labels) {
    // Create your model
  }
  predict(testData) {
    // Apply your model and return predicted label
    return prediction;
  }
}
```
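Once the placeholder bodies above are filled in with a real model, such a class can be passed directly as the `Classifier` argument. A minimal sketch (the data, options object and fold count here are illustrative only):
```js
const crossValidation = require('ml-cross-validation');

const features = [[0, 0], [0, 1], [1, 1], [2, 2]];
const labels = [0, 0, 1, 1];

// Per the conventions above, for each split ml-cross-validation roughly does:
//   const classifier = new MyClassifier(options);
//   classifier.train(trainFeatures, trainLabels);
//   const predicted = classifier.predict(testFeatures);
const options = {}; // placeholder: whatever options your constructor expects
const confusionMatrix = crossValidation.kFold(MyClassifier, features, labels, options, 2);
console.log(confusionMatrix.getAccuracy());
```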
[npm-image]: https://img.shields.io/npm/v/ml-cross-validation.svg?style=flat-square
@@ -16,0 +73,0 @@ [npm-url]: https://npmjs.org/package/ml-cross-validation
@@ -9,6 +9,7 @@ 'use strict';
/**
 * Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the validation
 * set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a special case
 * of LPO-CV. @see leavePout
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
 * Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the
 * validation set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a
 * special case of LPO-CV. @see leavePout
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
 * api.
 * @param {Array} features - The features for all samples of the data-set
@@ -20,2 +21,8 @@ * @param {Array} labels - The classification class of all samples of the data-set
CV.leaveOneOut = function (Classifier, features, labels, classifierOptions) {
    if (typeof labels === 'function') {
        var callback = labels;
        labels = features;
        features = Classifier;
        return CV.leavePOut(features, labels, 1, callback);
    }
    return CV.leavePOut(Classifier, features, labels, classifierOptions, 1);
@@ -30,3 +37,4 @@ };
 * data-set size this can require a very large number of training and testing to do!
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
 * api.
 * @param {Array} features - The features for all samples of the data-set
@@ -39,2 +47,8 @@ * @param {Array} labels - The classification class of all samples of the data-set
CV.leavePOut = function (Classifier, features, labels, classifierOptions, p) {
    if (typeof classifierOptions === 'function') {
        var callback = classifierOptions;
        p = labels;
        labels = features;
        features = Classifier;
    }
    check(features, labels);
@@ -57,3 +71,8 @@ const distinct = getDistinct(labels);
        validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
        if (callback) {
            validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
        } else {
            validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
        }
    }
@@ -76,2 +95,8 @@
CV.kFold = function (Classifier, features, labels, classifierOptions, k) {
    if (typeof classifierOptions === 'function') {
        var callback = classifierOptions;
        k = labels;
        labels = features;
        features = Classifier;
    }
    check(features, labels);
@@ -110,3 +135,7 @@ const distinct = getDistinct(labels);
        validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
        if (callback) {
            validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
        } else {
            validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
        }
    }
@@ -136,14 +165,3 @@
function validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct) {
    var testFeatures = testIdx.map(function (index) {
        return features[index];
    });
    var trainFeatures = trainIdx.map(function (index) {
        return features[index];
    });
    var testLabels = testIdx.map(function (index) {
        return labels[index];
    });
    var trainLabels = trainIdx.map(function (index) {
        return labels[index];
    });
    const {testFeatures, trainFeatures, testLabels, trainLabels} = getTrainTest(features, labels, testIdx, trainIdx);
@@ -159,7 +177,42 @@ var classifier;
    var predictedLabels = classifier.predict(testFeatures);
    updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct);
}
function validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback) {
    const {testFeatures, trainFeatures, testLabels, trainLabels} = getTrainTest(features, labels, testIdx, trainIdx);
    const predictedLabels = callback(trainFeatures, trainLabels, testFeatures);
    updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct);
}
function updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct) {
    for (var i = 0; i < predictedLabels.length; i++) {
        confusionMatrix[distinct.indexOf(testLabels[i])][distinct.indexOf(predictedLabels[i])]++;
        const actualIdx = distinct.indexOf(testLabels[i]);
        const predictedIdx = distinct.indexOf(predictedLabels[i]);
        if (actualIdx < 0 || predictedIdx < 0) {
            // eslint-disable-next-line no-console
            console.warn(`ignore unknown predicted label ${predictedLabels[i]}`);
            // actually ignore the sample: indexing the matrix with -1 would be wrong
            continue;
        }
        confusionMatrix[actualIdx][predictedIdx]++;
    }
}
function getTrainTest(features, labels, testIdx, trainIdx) {
    return {
        testFeatures: testIdx.map(function (index) {
            return features[index];
        }),
        trainFeatures: trainIdx.map(function (index) {
            return features[index];
        }),
        testLabels: testIdx.map(function (index) {
            return labels[index];
        }),
        trainLabels: trainIdx.map(function (index) {
            return labels[index];
        })
    };
}
module.exports = CV;
@@ -6,9 +6,13 @@ 'use strict';
var LOO = require('./data/LOO-CV');
var LPO = require('./data/LPO-CV');
var KF = require('./data/KF-CV');
describe('basic', function () {
    it('basic leave-one-out cross-validation', function () {
        var LOO = require('./data/LOO-CV');
        for (let i = 0; i < LOO.length; i++) {
            var CM = CV.leaveOneOut(Dummy, LOO[i].features, LOO[i].labels, LOO[i].classifierOptions);
            CM.matrix.should.deepEqual(LOO[i].result.matrix);
            CM.labels.should.deepEqual(LOO[i].result.labels);
            CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
            CM.getLabels().should.deepEqual(LOO[i].result.labels);
        }
@@ -18,7 +22,6 @@ });
    it('basic leave-p-out cross-validation', function () {
        var LPO = require('./data/LPO-CV');
        for (let i = 0; i < LPO.length; i++) {
            var CM = CV.leavePOut(Dummy, LPO[i].features, LPO[i].labels, LPO[i].classifierOptions, LPO[i].p);
            CM.matrix.should.deepEqual(LPO[i].result.matrix);
            CM.labels.should.deepEqual(LPO[i].result.labels);
            CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
            CM.getLabels().should.deepEqual(LPO[i].result.labels);
        }
@@ -28,9 +31,46 @@ });
    it('basic k-fold cross-validation', function () {
        var KF = require('./data/KF-CV');
        for (let i = 0; i < KF.length; i++) {
            var CM = CV.kFold(Dummy, KF[i].features, KF[i].labels, KF[i].classifierOptions, KF[i].k);
            CM.matrix.should.deepEqual(KF[i].result.matrix);
            CM.labels.should.deepEqual(KF[i].result.labels);
            CM.getMatrix().should.deepEqual(KF[i].result.matrix);
            CM.getLabels().should.deepEqual(KF[i].result.labels);
        }
    });
});
describe('with a callback', function () {
    it('basic leave-one-out cross-validation with callback', function () {
        for (let i = 0; i < LOO.length; i++) {
            var CM = CV.leaveOneOut(LOO[i].features, LOO[i].labels, function (trainFeatures, trainLabels, testFeatures) {
                const classifier = new Dummy(LOO[i].classifierOptions);
                classifier.train(trainFeatures, trainLabels);
                return classifier.predict(testFeatures);
            });
            CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
            CM.getLabels().should.deepEqual(LOO[i].result.labels);
        }
    });
    it('basic leave-p-out cross-validation with callback', function () {
        for (let i = 0; i < LPO.length; i++) {
            var CM = CV.leavePOut(LPO[i].features, LPO[i].labels, LPO[i].p, function (trainFeatures, trainLabels, testFeatures) {
                const classifier = new Dummy(LPO[i].classifierOptions);
                classifier.train(trainFeatures, trainLabels);
                return classifier.predict(testFeatures);
            });
            CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
            CM.getLabels().should.deepEqual(LPO[i].result.labels);
        }
    });
    it('basic k-fold cross-validation with callback', function () {
        for (let i = 0; i < KF.length; i++) {
            var CM = CV.kFold(KF[i].features, KF[i].labels, KF[i].k, function (trainFeatures, trainLabels, testFeatures) {
                const classifier = new Dummy(KF[i].classifierOptions);
                classifier.train(trainFeatures, trainLabels);
                return classifier.predict(testFeatures);
            });
            CM.getMatrix().should.deepEqual(KF[i].result.matrix);
            CM.getLabels().should.deepEqual(KF[i].result.labels);
        }
    });
});
Major refactor
Supply chain risk: the package has recently undergone a major refactor. It may be unstable or indicate significant internal changes. Use caution when updating to versions that include significant changes.
Found 1 instance in 1 package