ml-cross-validation
Comparing version 1.1.0 to 1.2.0
@@ -0,1 +1,11 @@
<a name="1.2.0"></a>
# [1.2.0](https://github.com/mljs/cross-validation/compare/v1.1.0...v1.2.0) (2017-11-08)
### Features
* allow passing a callback to cross-validation methods ([32501e2](https://github.com/mljs/cross-validation/commit/32501e2))
<a name="1.1.0"></a>
@@ -2,0 +12,0 @@ # [1.1.0](https://github.com/mljs/cross-validation/compare/v1.0.1...v1.1.0) (2017-07-07)
{
  "name": "ml-cross-validation",
  "version": "1.1.0",
  "version": "1.2.0",
  "description": "Cross validation utility for mljs classifiers",
@@ -33,3 +33,3 @@ "main": "src",
  "eslint-plugin-no-only-tests": "^2.0.0",
  "mocha": "^3.1.2",
  "mocha": "^3.5.3",
  "mocha-better-spec-reporter": "^3.0.2",
@@ -36,0 +36,0 @@ "should": "^11.1.1"
@@ -7,9 +7,66 @@ # cross-validation
Utility library to do cross validation with mljs classifiers
Utility library to do cross validation with supervised classifiers.
A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section
Cross-validation methods:
- [k-fold](https://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation)
- [leave-p-out](https://en.wikipedia.org/wiki/Cross-validation_(statistics)#Leave-p-out_cross-validation)
[Documentation](https://mljs.github.io/cross-validation/)
[API documentation](https://mljs.github.io/cross-validation/).
A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section, but you could also use your own. Cross-validation methods return a ConfusionMatrix ([https://github.com/mljs/confusion-matrix](https://github.com/mljs/confusion-matrix)) that can be used to calculate metrics on your classification result.
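Besides the overall accuracy used in the examples below, the returned confusion matrix exposes its raw counts and label order (`getMatrix()` and `getLabels()` are the accessors exercised by this package's own tests). A minimal sketch, reusing the classifier and data from the examples below:

```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');

const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];

const cm = crossValidation.leaveOneOut(KNN, dataset, labels);
console.log(cm.getAccuracy()); // fraction of correctly classified samples
console.log(cm.getLabels());   // distinct labels, in the order used by the matrix
console.log(cm.getMatrix());   // 2D array: rows = actual labels, columns = predicted labels
```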
## Installation
```bash
npm i -s ml-cross-validation
```
## Example using a ml classification library
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];
const confusionMatrix = crossValidation.leaveOneOut(KNN, dataset, labels);
const accuracy = confusionMatrix.getAccuracy();
```
## Example using a classifier with its own specific API
If you have a library that does not comply with the ML classifier API conventions, you can use a callback to perform the classification.
The callback receives the train features, the train labels and the test features, and should return the array of predicted labels.
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];
const confusionMatrix = crossValidation.leaveOneOut(dataset, labels, function (trainFeatures, trainLabels, testFeatures) {
  const knn = new KNN(trainFeatures, trainLabels);
  return knn.predict(testFeatures);
});
const accuracy = confusionMatrix.getAccuracy();
```
## ML classifier API conventions
You can write your classification library so that it can be used with ml-cross-validation as described [here](#example-using-a-ml-classification-library).
For that, your classification library must implement:
- A constructor. The constructor can be passed options as a single argument.
- A `train` method. The `train` method is passed the data as a first argument and the labels as a second.
- A `predict` method. The `predict` method is passed an array of test samples and should return the corresponding array of predicted labels.
### Example
```js
class MyClassifier {
  constructor(options) {
    this.options = options;
  }
  train(data, labels) {
    // Create your model
  }
  predict(testData) {
    // Apply your model and return the predicted labels
    return predictions;
  }
}
```
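A classifier written this way can be passed straight to any of the cross-validation methods. A minimal sketch using `kFold` (the data, options and k value here are made up):

```js
const crossValidation = require('ml-cross-validation');

const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];

// Constructor first, then data, labels, classifier options, and the number of folds.
const confusionMatrix = crossValidation.kFold(MyClassifier, dataset, labels, {/* options */}, 3);
console.log(confusionMatrix.getAccuracy());
```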
[npm-image]: https://img.shields.io/npm/v/ml-cross-validation.svg?style=flat-square
@@ -16,0 +73,0 @@ [npm-url]: https://npmjs.org/package/ml-cross-validation
@@ -9,6 +9,7 @@ 'use strict';
/**
 * Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the validation
 * set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a special case
 * of LPO-CV. @see leavePout
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
 * Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the
 * validation set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a
 * special case of LPO-CV. @see leavePout
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
 * api.
 * @param {Array} features - The features for all samples of the data-set
@@ -20,2 +21,8 @@ * @param {Array} labels - The classification class of all samples of the data-set
CV.leaveOneOut = function (Classifier, features, labels, classifierOptions) {
  if (typeof labels === 'function') {
    var callback = labels;
    labels = features;
    features = Classifier;
    return CV.leavePOut(features, labels, 1, callback);
  }
  return CV.leavePOut(Classifier, features, labels, classifierOptions, 1);
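As a quick illustration of the two call shapes `leaveOneOut` now accepts (`MyClassifier`, `features` and `labels` are placeholder names, not part of the package):

```js
// Classic form: classifier constructor, data, labels, classifier options.
const cm1 = CV.leaveOneOut(MyClassifier, features, labels, {/* classifierOptions */});

// Callback form added in 1.2.0: the callback gets train features/labels and test features
// and must return the predicted labels for the test set.
const cm2 = CV.leaveOneOut(features, labels, (trainFeatures, trainLabels, testFeatures) => {
  const model = new MyClassifier();
  model.train(trainFeatures, trainLabels);
  return model.predict(testFeatures);
});
```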
@@ -30,3 +37,4 @@ };
 * data-set size this can require a very large number of training and testing to do!
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
 * @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
 * api.
 * @param {Array} features - The features for all samples of the data-set
@@ -39,2 +47,8 @@ * @param {Array} labels - The classification class of all samples of the data-set
CV.leavePOut = function (Classifier, features, labels, classifierOptions, p) {
  if (typeof classifierOptions === 'function') {
    var callback = classifierOptions;
    p = labels;
    labels = features;
    features = Classifier;
  }
  check(features, labels);
@@ -57,3 +71,8 @@ const distinct = getDistinct(labels);
    validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
    if (callback) {
      validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
    } else {
      validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
    }
  }
@@ -76,2 +95,8 @@
CV.kFold = function (Classifier, features, labels, classifierOptions, k) {
  if (typeof classifierOptions === 'function') {
    var callback = classifierOptions;
    k = labels;
    labels = features;
    features = Classifier;
  }
  check(features, labels);
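`leavePOut` and `kFold` apply the same argument shuffling: when a callback is supplied, `p` (respectively `k`) moves up to the third position. An illustrative sketch of the two `kFold` call shapes (placeholder names again):

```js
// Classic form: constructor, data, labels, options, then k.
const cmA = CV.kFold(MyClassifier, features, labels, {/* classifierOptions */}, 5);

// Callback form: k is the third argument, followed by the callback.
const cmB = CV.kFold(features, labels, 5, (trainFeatures, trainLabels, testFeatures) => {
  const model = new MyClassifier();
  model.train(trainFeatures, trainLabels);
  return model.predict(testFeatures);
});
```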
@@ -110,3 +135,7 @@ const distinct = getDistinct(labels);
    validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
    if (callback) {
      validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
    } else {
      validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
    }
  }
@@ -136,14 +165,3 @@
function validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct) {
  var testFeatures = testIdx.map(function (index) {
    return features[index];
  });
  var trainFeatures = trainIdx.map(function (index) {
    return features[index];
  });
  var testLabels = testIdx.map(function (index) {
    return labels[index];
  });
  var trainLabels = trainIdx.map(function (index) {
    return labels[index];
  });
  const {testFeatures, trainFeatures, testLabels, trainLabels} = getTrainTest(features, labels, testIdx, trainIdx);
@@ -159,7 +177,42 @@ var classifier;
  var predictedLabels = classifier.predict(testFeatures);
  updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct);
}

function validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback) {
  const {testFeatures, trainFeatures, testLabels, trainLabels} = getTrainTest(features, labels, testIdx, trainIdx);
  const predictedLabels = callback(trainFeatures, trainLabels, testFeatures);
  updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct);
}
function updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct) {
  for (var i = 0; i < predictedLabels.length; i++) {
    confusionMatrix[distinct.indexOf(testLabels[i])][distinct.indexOf(predictedLabels[i])]++;
    const actualIdx = distinct.indexOf(testLabels[i]);
    const predictedIdx = distinct.indexOf(predictedLabels[i]);
    if (actualIdx < 0 || predictedIdx < 0) {
      // eslint-disable-next-line no-console
      console.warn(`ignore unknown predicted label ${predictedLabels[i]}`);
      continue; // skip this sample instead of indexing the matrix at -1
    }
    confusionMatrix[actualIdx][predictedIdx]++;
  }
}
function getTrainTest(features, labels, testIdx, trainIdx) {
  return {
    testFeatures: testIdx.map(function (index) {
      return features[index];
    }),
    trainFeatures: trainIdx.map(function (index) {
      return features[index];
    }),
    testLabels: testIdx.map(function (index) {
      return labels[index];
    }),
    trainLabels: trainIdx.map(function (index) {
      return labels[index];
    })
  };
}

module.exports = CV;
@@ -6,9 +6,13 @@ 'use strict';
var LOO = require('./data/LOO-CV');
var LPO = require('./data/LPO-CV');
var KF = require('./data/KF-CV');

describe('basic', function () {
  it('basic leave-one-out cross-validation', function () {
    var LOO = require('./data/LOO-CV');
    for (let i = 0; i < LOO.length; i++) {
      var CM = CV.leaveOneOut(Dummy, LOO[i].features, LOO[i].labels, LOO[i].classifierOptions);
      CM.matrix.should.deepEqual(LOO[i].result.matrix);
      CM.labels.should.deepEqual(LOO[i].result.labels);
      CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
      CM.getLabels().should.deepEqual(LOO[i].result.labels);
    }
@@ -18,7 +22,6 @@ });
  it('basic leave-p-out cross-validation', function () {
    var LPO = require('./data/LPO-CV');
    for (let i = 0; i < LPO.length; i++) {
      var CM = CV.leavePOut(Dummy, LPO[i].features, LPO[i].labels, LPO[i].classifierOptions, LPO[i].p);
      CM.matrix.should.deepEqual(LPO[i].result.matrix);
      CM.labels.should.deepEqual(LPO[i].result.labels);
      CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
      CM.getLabels().should.deepEqual(LPO[i].result.labels);
    }
@@ -28,9 +31,46 @@ });
  it('basic k-fold cross-validation', function () {
    var KF = require('./data/KF-CV');
    for (let i = 0; i < KF.length; i++) {
      var CM = CV.kFold(Dummy, KF[i].features, KF[i].labels, KF[i].classifierOptions, KF[i].k);
      CM.matrix.should.deepEqual(KF[i].result.matrix);
      CM.labels.should.deepEqual(KF[i].result.labels);
      CM.getMatrix().should.deepEqual(KF[i].result.matrix);
      CM.getLabels().should.deepEqual(KF[i].result.labels);
    }
  });
});
describe('with a callback', function () {
  it('basic leave-one-out cross-validation with callback', function () {
    for (let i = 0; i < LOO.length; i++) {
      var CM = CV.leaveOneOut(LOO[i].features, LOO[i].labels, function (trainFeatures, trainLabels, testFeatures) {
        const classifier = new Dummy(LOO[i].classifierOptions);
        classifier.train(trainFeatures, trainLabels);
        return classifier.predict(testFeatures);
      });
      CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
      CM.getLabels().should.deepEqual(LOO[i].result.labels);
    }
  });
  it('basic leave-p-out cross-validation with callback', function () {
    for (let i = 0; i < LPO.length; i++) {
      var CM = CV.leavePOut(LPO[i].features, LPO[i].labels, LPO[i].p, function (trainFeatures, trainLabels, testFeatures) {
        const classifier = new Dummy(LPO[i].classifierOptions);
        classifier.train(trainFeatures, trainLabels);
        return classifier.predict(testFeatures);
      });
      CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
      CM.getLabels().should.deepEqual(LPO[i].result.labels);
    }
  });
  it('basic k-fold cross-validation with callback', function () {
    for (let i = 0; i < KF.length; i++) {
      var CM = CV.kFold(KF[i].features, KF[i].labels, KF[i].k, function (trainFeatures, trainLabels, testFeatures) {
        const classifier = new Dummy(KF[i].classifierOptions);
        classifier.train(trainFeatures, trainLabels);
        return classifier.predict(testFeatures);
      });
      CM.getMatrix().should.deepEqual(KF[i].result.matrix);
      CM.getLabels().should.deepEqual(KF[i].result.labels);
    }
  });
});
License Policy Violation
License: This package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
Major refactor
Supply chain risk: Package has recently undergone a major refactor. It may be unstable or indicate significant internal changes. Use caution when updating to versions that include significant changes.
Found 1 instance in 1 package