Huge News!Announcing our $40M Series B led by Abstract Ventures.Learn More
Socket
Sign inDemoInstall
Socket

ml-cross-validation

Package Overview
Dependencies
Maintainers
7
Versions
5
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

ml-cross-validation - npm Package Compare versions

Comparing version 1.1.0 to 1.2.0

10

History.md

@@ -0,1 +1,11 @@

<a name="1.2.0"></a>
# [1.2.0](https://github.com/mljs/cross-validation/compare/v1.1.0...v1.2.0) (2017-11-08)
### Features
* allow passing a callback to cross-validation methods ([32501e2](https://github.com/mljs/cross-validation/commit/32501e2))
<a name="1.1.0"></a>

@@ -2,0 +12,0 @@ # [1.1.0](https://github.com/mljs/cross-validation/compare/v1.0.1...v1.1.0) (2017-07-07)

4

package.json
{
"name": "ml-cross-validation",
"version": "1.1.0",
"version": "1.2.0",
"description": "Cross validation utility for mljs classifiers",

@@ -33,3 +33,3 @@ "main": "src",

"eslint-plugin-no-only-tests": "^2.0.0",
"mocha": "^3.1.2",
"mocha": "^3.5.3",
"mocha-better-spec-reporter": "^3.0.2",

@@ -36,0 +36,0 @@ "should": "^11.1.1"

@@ -7,9 +7,66 @@ # cross-validation

Utility library to do cross validation with mljs classifiers
Utility library to do cross validation with supervised classifiers.
A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section
Cross-validation methods:
- [k-fold](https://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation)
- [leave-p-out](https://en.wikipedia.org/wiki/Cross-validation_(statistics)#Leave-p-out_cross-validation)
[Documentation](https://mljs.github.io/cross-validation/)
[API documentation](https://mljs.github.io/cross-validation/).
A list of the mljs supervised classifiers is available [here](https://github.com/mljs/ml#tools) in the supervised learning section, but you could also use your own. Cross validations methods return a ConfusionMatrix ([https://github.com/mljs/confusion-matrix](https://github.com/mljs/confusion-matrix)) that can be used to calculate metrics on your classification result.
## Installation
```bash
npm i -s ml-cross-validation
```
## Example using a ml classification library
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];
const confusionMatrix = crossValidation.leaveOneOut(KNN, dataSet, labels);
const accuracy = confusionMatrix.getAccuracy();
```
## Example using a classifier with its own specific API
If you have a library that does not comply with the ML Classifier conventions, you can use can use a callback to perform the classification.
The callback will take the train features and labels, and the test features. The callback shoud return the array of predicted labels.
```js
const crossValidation = require('ml-cross-validation');
const KNN = require('ml-knn');
const dataset = [[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 2, 2], [1, 2, 2], [2, 1, 2]];
const labels = [0, 0, 0, 1, 1, 1];
const confusionMatrix = crossValidation.leaveOneOut(dataSet, labels, function(trainFeatures, trainLabels, testFeatures) {
const knn = new KNN(trainFeatures, trainLabels);
return knn.predict(testFeatures);
});
const accuracy = confusionMatrix.getAccuracy();
```
## ML classifier API conventions
You can write your classification library so that it can be used with ml-cross-validation as described in [here](#example-using-a-ml-classification-library)
For that, your classification library must implement
- A constructor. The constructor can be passed options as a single argument.
- A `train` method. The `train` method is passed the data as a first argument and the labels as a second.
- A `predict` method. The `predict` method is passed test data and should return a predicted label.
### Example
```js
class MyClassifier {
constructor(options) {
this.options = options;
}
train(data, labels) {
// Create your model
}
predict(testData) {
// Apply your model and return predicted label
return prediction;
}
}
```
###
[npm-image]: https://img.shields.io/npm/v/ml-cross-validation.svg?style=flat-square

@@ -16,0 +73,0 @@ [npm-url]: https://npmjs.org/package/ml-cross-validation

@@ -9,6 +9,7 @@ 'use strict';

/**
* Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the validation
* set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a special case
* of LPO-CV. @see leavePout
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
* Performs a leave-one-out cross-validation (LOO-CV) of the given samples. In LOO-CV, 1 observation is used as the
* validation set while the rest is used as the training set. This is repeated once for each observation. LOO-CV is a
* special case of LPO-CV. @see leavePout
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
* api.
* @param {Array} features - The features for all samples of the data-set

@@ -20,2 +21,8 @@ * @param {Array} labels - The classification class of all samples of the data-set

CV.leaveOneOut = function (Classifier, features, labels, classifierOptions) {
if (typeof labels === 'function') {
var callback = labels;
labels = features;
features = Classifier;
return CV.leavePOut(features, labels, 1, callback);
}
return CV.leavePOut(Classifier, features, labels, classifierOptions, 1);

@@ -30,3 +37,4 @@ };

* data-set size this can require a very large number of training and testing to do!
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier api.
* @param {function} Classifier - The classifier's constructor to use for the cross validation. Expect ml-classifier
* api.
* @param {Array} features - The features for all samples of the data-set

@@ -39,2 +47,8 @@ * @param {Array} labels - The classification class of all samples of the data-set

CV.leavePOut = function (Classifier, features, labels, classifierOptions, p) {
if (typeof classifierOptions === 'function') {
var callback = classifierOptions;
p = labels;
labels = features;
features = Classifier;
}
check(features, labels);

@@ -57,3 +71,8 @@ const distinct = getDistinct(labels);

validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
if (callback) {
validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
} else {
validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
}
}

@@ -76,2 +95,8 @@

CV.kFold = function (Classifier, features, labels, classifierOptions, k) {
if (typeof classifierOptions === 'function') {
var callback = classifierOptions;
k = labels;
labels = features;
features = Classifier;
}
check(features, labels);

@@ -110,3 +135,7 @@ const distinct = getDistinct(labels);

validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
if (callback) {
validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback);
} else {
validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct);
}
}

@@ -136,14 +165,3 @@

function validate(Classifier, features, labels, classifierOptions, testIdx, trainIdx, confusionMatrix, distinct) {
var testFeatures = testIdx.map(function (index) {
return features[index];
});
var trainFeatures = trainIdx.map(function (index) {
return features[index];
});
var testLabels = testIdx.map(function (index) {
return labels[index];
});
var trainLabels = trainIdx.map(function (index) {
return labels[index];
});
const {testFeatures, trainFeatures, testLabels, trainLabels} = getTrainTest(features, labels, testIdx, trainIdx);

@@ -159,7 +177,42 @@ var classifier;

var predictedLabels = classifier.predict(testFeatures);
updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct);
}
function validateWithCallback(features, labels, testIdx, trainIdx, confusionMatrix, distinct, callback) {
const {testFeatures, trainFeatures, testLabels, trainLabels} = getTrainTest(features, labels, testIdx, trainIdx);
const predictedLabels = callback(trainFeatures, trainLabels, testFeatures);
updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct);
}
function updateConfusionMatrix(confusionMatrix, testLabels, predictedLabels, distinct) {
for (var i = 0; i < predictedLabels.length; i++) {
confusionMatrix[distinct.indexOf(testLabels[i])][distinct.indexOf(predictedLabels[i])]++;
const actualIdx = distinct.indexOf(testLabels[i]);
const predictedIdx = distinct.indexOf(predictedLabels[i]);
if (actualIdx < 0 || predictedIdx < 0) {
// eslint-disable-next-line no-console
console.warn(`ignore unknown predicted label ${predictedLabels[i]}`);
}
confusionMatrix[actualIdx][predictedIdx]++;
}
}
function getTrainTest(features, labels, testIdx, trainIdx) {
return {
testFeatures: testIdx.map(function (index) {
return features[index];
}),
trainFeatures: trainIdx.map(function (index) {
return features[index];
}),
testLabels: testIdx.map(function (index) {
return labels[index];
}),
trainLabels: trainIdx.map(function (index) {
return labels[index];
})
};
}
module.exports = CV;

@@ -6,9 +6,13 @@ 'use strict';

var LOO = require('./data/LOO-CV');
var LPO = require('./data/LPO-CV');
var KF = require('./data/KF-CV');
describe('basic', function () {
it('basic leave-one-out cross-validation', function () {
var LOO = require('./data/LOO-CV');
for (let i = 0; i < LOO.length; i++) {
var CM = CV.leaveOneOut(Dummy, LOO[i].features, LOO[i].labels, LOO[i].classifierOptions);
CM.matrix.should.deepEqual(LOO[i].result.matrix);
CM.labels.should.deepEqual(LOO[i].result.labels);
CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
CM.getLabels().should.deepEqual(LOO[i].result.labels);
}

@@ -18,7 +22,6 @@ });

it('basic leave-p-out cross-validation', function () {
var LPO = require('./data/LPO-CV');
for (let i = 0; i < LPO.length; i++) {
var CM = CV.leavePOut(Dummy, LPO[i].features, LPO[i].labels, LPO[i].classifierOptions, LPO[i].p);
CM.matrix.should.deepEqual(LPO[i].result.matrix);
CM.labels.should.deepEqual(LPO[i].result.labels);
CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
CM.getLabels().should.deepEqual(LPO[i].result.labels);
}

@@ -28,9 +31,46 @@ });

it('basic k-fold cross-validation', function () {
var KF = require('./data/KF-CV');
for (let i = 0; i < KF.length; i++) {
var CM = CV.kFold(Dummy, KF[i].features, KF[i].labels, KF[i].classifierOptions, KF[i].k);
CM.matrix.should.deepEqual(KF[i].result.matrix);
CM.labels.should.deepEqual(KF[i].result.labels);
CM.getMatrix().should.deepEqual(KF[i].result.matrix);
CM.getLabels().should.deepEqual(KF[i].result.labels);
}
});
});
describe('with a callback', function () {
it('basic leave-on-out cross-validation with callback', function () {
for (let i = 0; i < LOO.length; i++) {
var CM = CV.leaveOneOut(LOO[i].features, LOO[i].labels, function (trainFeatures, trainLabels, testFeatures) {
const classifier = new Dummy(LOO[i].classifierOptions);
classifier.train(trainFeatures, trainLabels);
return classifier.predict(testFeatures);
});
CM.getMatrix().should.deepEqual(LOO[i].result.matrix);
CM.getLabels().should.deepEqual(LOO[i].result.labels);
}
});
it('basic leave-p-out cross-validation with callback', function () {
for (let i = 0; i < LPO.length; i++) {
var CM = CV.leavePOut(LPO[i].features, LPO[i].labels, LPO[i].p, function (trainFeatures, trainLabels, testFeatures) {
const classifier = new Dummy(LPO[i].classifierOptions);
classifier.train(trainFeatures, trainLabels);
return classifier.predict(testFeatures);
});
CM.getMatrix().should.deepEqual(LPO[i].result.matrix);
CM.getLabels().should.deepEqual(LPO[i].result.labels);
}
});
it('basic k-fold cross-validation with callback', function () {
for (let i = 0; i < KF.length; i++) {
var CM = CV.kFold(KF[i].features, KF[i].labels, KF[i].k, function (trainFeatures, trainLabels, testFeatures) {
const classifier = new Dummy(KF[i].classifierOptions);
classifier.train(trainFeatures, trainLabels);
return classifier.predict(testFeatures);
});
CM.getMatrix().should.deepEqual(KF[i].result.matrix);
CM.getLabels().should.deepEqual(KF[i].result.labels);
}
});
});

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

SocketSocket SOC 2 Logo

Product

  • Package Alerts
  • Integrations
  • Docs
  • Pricing
  • FAQ
  • Roadmap
  • Changelog

Packages

npm

Stay in touch

Get open source security insights delivered straight into your inbox.


  • Terms
  • Privacy
  • Security

Made with ⚡️ by Socket Inc