Comparing version 1.0.1 with 1.1.0
 'use strict';
 const jimp = require("jimp");
 const assert = require('assert');
 const fs = require('fs');
+const ndarray = require('ndarray');
+const dtype = require('dtype');
 const _ = require('lodash');
@@ -32,6 +33,6 @@ const menoh = require('..'); // This menoh module
     var outp = [];
-    for (var i = 0; i < a.length; i++) {
+    for (var i = 0; i < a.size; i++) {
         outp.push(i); // add index to output array
         if (outp.length > k) {
-            outp.sort((l, r) => { return a[r] - a[l]; });
+            outp.sort((l, r) => { return a.get(r) - a.get(l); });
             outp.pop();
@@ -47,68 +48,67 @@ }
-loadInputImages()
-.then((imageList) => {
-    const data = [];
-    imageList.forEach((image, batchIdx) => {
-        // All the input images are already croped and resized to 28 x 28.
-        assert.equal(image.bitmap.width, 28);
-        assert.equal(image.bitmap.height, 28);
-        /*
-        */
-        // Convert bitmap to an array. (Use R channel only - already in greyscale)
-        const numPixels = image.bitmap.width * image.bitmap.height;
-        const batchOffset = batchIdx * numPixels;
-        image.scan(0, 0, image.bitmap.width, image.bitmap.height, function (x, y, idx) {
-            const dataIdx = y * image.bitmap.width + x + batchOffset;
-            data[dataIdx] = this.bitmap.data[idx]; // R channel
-        });
-    });
-    // Load ONNX file
-    return menoh.create('../test/data/mnist/mnist.onnx')
-    .then((builder) => {
-        const batchSize = imageList.length;
-        // Add input data
-        builder.addInput(MNIST_IN_NAME, [
-            batchSize, // 10 images in the data
-            1,  // number of channels
-            28, // height
-            28  // width
-        ]);
-        // Add output
-        builder.addOutput(MNIST_OUT_NAME);
-        // Build a new Model
-        const model = builder.buildModel({
-            backendName: 'mkldnn'
-        })
-        // Set input data
-        model.setInputData(MNIST_IN_NAME, data);
+// Load ONNX file
+return menoh.create('../test/data/mnist/mnist.onnx')
+.then((builder) => {
+    const batchSize = INPUT_IMAGE_LIST.length;
+    // Add input data
+    builder.addInput(MNIST_IN_NAME, [
+        batchSize, // 10 images in the data
+        1,  // number of channels
+        28, // height
+        28  // width
+    ]);
+    // Add output
+    builder.addOutput(MNIST_OUT_NAME);
+    // Build a new Model
+    const model = builder.buildModel({
+        backendName: 'mkldnn'
+    })
+    // Create a view for input buffer using ndarray.
+    const iData = (function () {
+        const prof = model.getProfile(MNIST_IN_NAME);
+        return ndarray(new (dtype(prof.dtype))(prof.buf.buffer), prof.dims);
+    })();
+    // Create a view for output buffer using ndarray.
+    const oData = (function () {
+        const prof = model.getProfile(MNIST_OUT_NAME);
+        return ndarray(new (dtype(prof.dtype))(prof.buf.buffer), prof.dims);
+    })();
+    return loadInputImages()
+    .then((imageList) => {
+        imageList.forEach((image, batchIdx) => {
+            // All the input images are already cropped and resized to 28 x 28.
+            // Now, copy the image data into the input buffer in NCHW format.
+            image.scan(0, 0, image.bitmap.width, image.bitmap.height, (x, y, idx) => {
+                const val = image.bitmap.data[idx];
+                iData.set(batchIdx, 0, y, x, val);
+            });
+        });
         // Run the model
         return model.run()
         .then(() => {
-            const out = model.getOutput(MNIST_OUT_NAME);
-            // just to be sure
-            assert.equal(out.dims[0] * out.dims[1], out.data.length);
-            assert.equal(out.dims[0], batchSize); // only applies to this example
             // Print the results.
-            out.data = _.chunk(out.data, out.dims[1]); // reshaped
             for (let bi = 0; bi < batchSize; ++bi) {
                 console.log('### Result for %s', INPUT_IMAGE_LIST[bi]);
-                const topK = findIndicesOfTopK(out.data[bi], 1);
+                const topK = findIndicesOfTopK(oData.pick(bi, null), 1);
                 topK.forEach((i) => {
-                    console.log('[%d] %f %s', i, out.data[bi][i], categoryList[i]);
+                    console.log('[%d] %f %s', i, oData.get(bi, i), categoryList[i]);
                 });
             }
             // Happily done!
         });
     });
 })
 .catch((err) => {
     console.log('Error:', err);
 });
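Aside from the structural reshuffle, the key behavioral change above is that `findIndicesOfTopK()` now reads its input through the ndarray view API (`a.size`, `a.get(i)`) instead of plain array indexing. Below is a self-contained sketch of that same selection logic applied to a view; only the ndarray calls come from the diff, and the sample scores are made up for illustration:

```js
const ndarray = require('ndarray');

// Fake 1x4 softmax output; the values are illustrative only.
const scores = ndarray(new Float32Array([0.1, 0.7, 0.05, 0.15]), [1, 4]);
const row = scores.pick(0, null); // 1-D view of batch row 0

// Same top-k logic as the updated helper, with k = 2.
const outp = [];
for (let i = 0; i < row.size; i++) {
    outp.push(i); // add index to output array
    if (outp.length > 2) {
        outp.sort((l, r) => row.get(r) - row.get(l)); // indices by value, descending
        outp.pop(); // drop the index of the smallest value
    }
}
console.log(outp); // [1, 3] -- indices of the two largest scores
```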
@@ -6,2 +6,4 @@ 'use strict';
 const fs = require('fs');
+const ndarray = require('ndarray');
+const dtype = require('dtype');
 const _ = require('lodash');
@@ -51,6 +53,6 @@ const menoh = require('..'); // This menoh module
     var outp = [];
-    for (var i = 0; i < a.length; i++) {
+    for (var i = 0; i < a.size; i++) {
         outp.push(i); // add index to output array
         if (outp.length > k) {
-            outp.sort((l, r) => { return a[r] - a[l]; });
+            outp.sort((l, r) => { return a.get(r) - a.get(l); });
             outp.pop();
@@ -66,79 +68,80 @@ }
-loadInputImages()
-.then((imageList) => {
-    const data = [];
-    imageList.forEach((image, batchIdx) => {
-        // Crop the input image to a square shape.
-        cropToSquare(image);
-        // Resize it to 224 x 224.
-        image.resize(224, 224);
-        // Convert bitmap to an array.
-        const numPixels = image.bitmap.width * image.bitmap.height;
-        const batchOffset = batchIdx * numPixels * 3;
-        image.scan(0, 0, image.bitmap.width, image.bitmap.height, function (x, y, idx) {
-            for (let c = 0; c < 3; ++c) {
-                const dataIdx = c * numPixels + y * image.bitmap.width + x + batchOffset;
-                data[dataIdx] = this.bitmap.data[idx + c];
-            }
-        });
-    });
-    // Load ONNX file
-    return menoh.create('../test/data/vgg16/VGG16.onnx')
-    .then((builder) => {
-        const batchSize = imageList.length;
-        // Add input
-        builder.addInput(CONV1_1_IN_NAME, [
-            batchSize, // 2 images in the data
-            3,   // number of channels
-            224, // height
-            224  // width
-        ]);
-        // Add output
-        builder.addOutput(FC6_OUT_NAME);
-        builder.addOutput(SOFTMAX_OUT_NAME);
-        // Build a new Model
-        const model = builder.buildModel({
-            backendName: 'mkldnn'
-        })
-        // Set input data
-        model.setInputData(CONV1_1_IN_NAME, data);
+// Load ONNX file
+return menoh.create('../test/data/vgg16/VGG16.onnx')
+.then((builder) => {
+    const batchSize = INPUT_IMAGE_LIST.length;
+    // Add input
+    builder.addInput(CONV1_1_IN_NAME, [
+        batchSize, // 2 images in the data
+        3,   // number of channels
+        224, // height
+        224  // width
+    ]);
+    // Add output
+    builder.addOutput(FC6_OUT_NAME);
+    builder.addOutput(SOFTMAX_OUT_NAME);
+    // Build a new Model
+    const model = builder.buildModel({
+        backendName: 'mkldnn'
+    })
+    // Create a view for input buffer using ndarray.
+    const iData = (function () {
+        const prof = model.getProfile(CONV1_1_IN_NAME);
+        return ndarray(new (dtype(prof.dtype))(prof.buf.buffer), prof.dims);
+    })();
+    // Create a view for each output buffer using ndarray.
+    const oDataFc6 = (function () {
+        const prof = model.getProfile(FC6_OUT_NAME);
+        return ndarray(new (dtype(prof.dtype))(prof.buf.buffer), prof.dims);
+    })();
+    const oDataSmx = (function () {
+        const prof = model.getProfile(SOFTMAX_OUT_NAME);
+        return ndarray(new (dtype(prof.dtype))(prof.buf.buffer), prof.dims);
+    })();
+    return loadInputImages()
+    .then((imageList) => {
+        const data = [];
+        imageList.forEach((image, batchIdx) => {
+            // Crop the input image to a square shape.
+            cropToSquare(image);
+            // Resize it to 224 x 224.
+            image.resize(224, 224);
+            // Now, copy the image data into the input buffer in NCHW format.
+            image.scan(0, 0, image.bitmap.width, image.bitmap.height, (x, y, idx) => {
+                for (let c = 0; c < 3; ++c) {
+                    const val = image.bitmap.data[idx + c];
+                    iData.set(batchIdx, c, y, x, val);
+                }
+            });
+        });
         // Run the model
         return model.run()
         .then(() => {
-            const out1 = model.getOutput(FC6_OUT_NAME);
-            const out2 = model.getOutput(SOFTMAX_OUT_NAME);
-            // just to be sure
-            assert.equal(out1.dims[0] * out1.dims[1], out1.data.length);
-            assert.equal(out2.dims[0] * out2.dims[1], out2.data.length);
-            assert.equal(out1.dims[0], batchSize); // only applies to this example
-            assert.equal(out2.dims[0], batchSize); // only applies to this example
             // Print the results.
-            out1.data = _.chunk(out1.data, out1.dims[1]); // reshaped
-            out2.data = _.chunk(out2.data, out2.dims[1]); // reshaped
             for (let bi = 0; bi < batchSize; ++bi) {
                 console.log('### Result for %s', INPUT_IMAGE_LIST[bi]);
-                console.log('fc6 out: %s ...', out1.data[bi].slice(0, 5).join(' '));
-                const topK = findIndicesOfTopK(out2.data[bi], 5);
+                const fc6 = oDataFc6.pick(bi, null);
+                console.log('fc6 out: %s ...', [0, 1, 2].map((i) => fc6.get(i)).join(' '));
+                const topK = findIndicesOfTopK(oDataSmx.pick(bi, null), 5);
                 console.log('Top 5 categories are:');
                 topK.forEach((i) => {
-                    console.log('[%d] %f %s', i, out2.data[bi][i], categoryList[i]);
+                    console.log('[%d] %f %s', i, oDataSmx.get(bi, i), categoryList[i]);
                 });
             }
             // Happily done!
         });
     });
 })
 .catch((err) => {
     console.log('Error:', err);
 });
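The VGG16 rewrite replaces the hand-computed flat index (`dataIdx = batchOffset + c * numPixels + y * width + x`) with `iData.set(batchIdx, c, y, x, val)`. For a default (row-major) ndarray over an [N, C, H, W] buffer these address the same element. A small sketch checking that equivalence; the coordinates are arbitrary sample values and nothing here is part of the diff itself:

```js
const ndarray = require('ndarray');

const [N, C, H, W] = [2, 3, 224, 224];
const buf = new Float32Array(N * C * H * W);
const iData = ndarray(buf, [N, C, H, W]); // row-major NCHW view

const [batchIdx, c, y, x] = [1, 2, 10, 20]; // arbitrary sample coordinates
iData.set(batchIdx, c, y, x, 0.5);

// The old example's manual offset arithmetic lands on the same slot.
const numPixels = H * W;
const batchOffset = batchIdx * numPixels * C;
const dataIdx = batchOffset + c * numPixels + y * W + x;
console.log(buf[dataIdx]); // 0.5
```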
 {
   "name": "menoh",
-  "version": "1.0.1",
+  "version": "1.1.0",
   "description": "NodeJS binding for Menoh DNN inference library.",
@@ -18,3 +18,4 @@ "main": "index.js",
     "dnn",
-    "mkl-dnn"
+    "mkl-dnn",
+    "chainer"
   ],
@@ -29,4 +30,6 @@ "author": "enobufs",
   "devDependencies": {
+    "dtype": "^2.0.0",
     "jimp": "^0.2.28",
-    "mocha": "^5.0.2"
+    "mocha": "^5.0.2",
+    "ndarray": "^1.0.18"
   },
 # menoh
 NodeJS binding for Menoh DNN inference library.
 ## Features
 * Fast DNN inference on Intel CPU.
 * Supports the standard [ONNX](http://onnx.ai/) format.
 * Easy to use.
 ## Requirements
 * MKL-DNN library [v0.14](https://github.com/intel/mkl-dnn/tree/v0.14) or later.
-* ProtocolBuffers (Only tested with v3.5.1)
-* [Menoh(C/C++) library](https://github.com/pfnet-research/menoh) v1.x (Only tested with v1.0.2)
+* ProtocolBuffers (Tested with v3.5.1)
+* [Menoh(C/C++) library](https://github.com/pfnet-research/menoh) v1.x (Tested with v1.0.2)
 * NodeJS v6 or greater
@@ -25,3 +30,3 @@
-For linux, you may need add `/usr/local/lib` to LD_LIBRARY_PATH depending on your linux distrubtion.
+For Linux, you may need to add `/usr/local/lib` to LD_LIBRARY_PATH depending on your Linux distribution.
 ```sh
@@ -139,5 +144,11 @@ export LD_LIBRARY_PATH=/usr/local/lib
 ### Model methods
-#### model.setInputData(input_var_name{string}, data{array})
-Sets input data for the given input name.
+#### model.getProfile(var_name{string}) => {object}
+Returns a profile object for the given variable name.
+The object has the following properties:
+* dims {array}: Buffer dimensions. (e.g. [1, 3, 224, 224])
+* buf {Buffer}: Buffer attached to the variable.
+* dtype {string}: Data type.
+> The current revision supports only one data type, "float32".
 #### model.run(cb) => {Promise}
@@ -148,3 +159,10 @@ Run inference. It returns a promise if `cb` is not provided. The actual inference takes place
+#### model.setInputData(input_var_name{string}, data{array})
+> DEPRECATED. Use model.getProfile() instead.
+Sets input data for the given input name.
 #### model.getOutput(output_var_name) => {object}
+> DEPRECATED. Use model.getProfile() instead.
 Returns the output object generated during `model.run()` for the given output name.
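Condensed, the view-based flow this release documents is: `model.getProfile()` hands back `{ dims, buf, dtype }`, and wrapping `buf` in an ndarray gives direct read/write access to the model's own input and output memory. A minimal sketch follows; the `createView` name is ours (it mirrors the `createBufferView` helper in the test diff below), and `model` is assumed to be a built model using the MNIST names from the examples:

```js
const ndarray = require('ndarray');
const dtype = require('dtype');

// Wrap a model variable's buffer in an ndarray view (hypothetical helper name).
function createView(model, name) {
    const prof = model.getProfile(name); // { dims, buf, dtype }
    return ndarray(new (dtype(prof.dtype))(prof.buf.buffer), prof.dims);
}

// Assumed usage with a built model:
// const iv = createView(model, MNIST_IN_NAME);  // shape [batchSize, 1, 28, 28]
// iv.set(0, 0, 14, 14, 255);                    // write one input pixel in place
// const ov = createView(model, MNIST_OUT_NAME); // shape [batchSize, 10]
// model.run().then(() => console.log(ov.get(0, 7)));
```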
@@ -5,2 +5,4 @@ 'use strict';
 const jimp = require("jimp");
+const ndarray = require("ndarray");
+const dtype = require("dtype");
 const fs = require('fs');
@@ -48,11 +50,26 @@ const _ = require('lodash');
+function createBufferView(model, name) {
+    const prof = model.getProfile(name);
+    return ndarray(new (dtype(prof.dtype))(prof.buf.buffer), prof.dims);
+}
 // Find the indexes of the k largest values.
 function findIndicesOfTopK(a, k) {
     var outp = [];
-    for (var i = 0; i < a.length; i++) {
-        outp.push(i); // add index to output array
-        if (outp.length > k) {
-            outp.sort((l, r) => { return a[r] - a[l]; });
-            outp.pop();
+    if (Array.isArray(a)) {
+        for (var i = 0; i < a.length; i++) {
+            outp.push(i); // add index to output array
+            if (outp.length > k) {
+                outp.sort((l, r) => { return a[r] - a[l]; });
+                outp.pop();
+            }
+        }
+    } else {
+        for (var i = 0; i < a.size; i++) {
+            outp.push(i); // add index to output array
+            if (outp.length > k) {
+                outp.sort((l, r) => { return a.get(r) - a.get(l); });
+                outp.pop();
+            }
+        }
+    }
@@ -62,16 +79,25 @@ return outp;
-function validateOutput(model, batchSize) {
-    const out = model.getOutput(MNIST_OUT_NAME);
-    // sanity check
-    assert.equal(out.dims[0] * out.dims[1], out.data.length);
-    assert.equal(out.dims[0], batchSize);
-    // Evaluate results.
-    out.data = _.chunk(out.data, out.dims[1]); // reshaped
-    for (let bi = 0; bi < batchSize; ++bi) {
-        const topK = findIndicesOfTopK(out.data[bi], 1);
-        topK.forEach((idx) => {
-            assert.equal(idx, bi);
-        });
+function validateOutput(output, batchSize) {
+    if (Array.isArray(output.data)) {
+        // sanity check
+        assert.equal(output.dims[0] * output.dims[1], output.data.length);
+        assert.equal(output.dims[0], batchSize);
+        // Evaluate results.
+        output.data = _.chunk(output.data, output.dims[1]); // reshaped
+        for (let bi = 0; bi < batchSize; ++bi) {
+            const topK = findIndicesOfTopK(output.data[bi], 1);
+            topK.forEach((idx) => {
+                assert.equal(idx, bi);
+            });
+        }
+    } else {
+        for (let bi = 0; bi < batchSize; ++bi) {
+            const topK = findIndicesOfTopK(output.pick(bi, null), 1);
+            topK.forEach((idx) => {
+                assert.equal(idx, bi);
+            });
+        }
+    }
 }
@@ -113,8 +139,19 @@
-            model.setInputData(MNIST_IN_NAME, data);
+            const iv = createBufferView(model, MNIST_IN_NAME);
+            const ov = createBufferView(model, MNIST_OUT_NAME);
+            assert.equal(iv.size, data.length);
+            assert.deepEqual(iv.shape, [batchSize, 1, 28, 28]);
+            assert.equal(ov.size, 100);
+            assert.deepEqual(ov.shape, [batchSize, 10]);
+            // Write input data to input view.
+            data.forEach((v, i) => {
+                iv.data[i] = v;
+            });
             // Run the model
             model.run((err) => {
                 assert.ifError(err);
-                validateOutput(model, batchSize);
+                validateOutput(ov, batchSize);
                 done();
@@ -125,3 +162,3 @@ });
-    it('Succeed with promise', function () {
+    it('Succeed using views', function () {
         // Load ONNX file
@@ -140,8 +177,19 @@ return menoh.create(ONNX_FILE_PATH)
-            model.setInputData(MNIST_IN_NAME, data);
+            const iv = createBufferView(model, MNIST_IN_NAME);
+            const ov = createBufferView(model, MNIST_OUT_NAME);
+            assert.equal(iv.size, data.length);
+            assert.deepEqual(iv.shape, [batchSize, 1, 28, 28]);
+            assert.equal(ov.size, 100);
+            assert.deepEqual(ov.shape, [batchSize, 10]);
+            // Write input data to input view.
+            data.forEach((v, i) => {
+                iv.data[i] = v;
+            });
             // Run the model
             return model.run()
             .then(() => {
-                validateOutput(model, batchSize);
+                validateOutput(ov, batchSize);
             });
@@ -165,16 +213,21 @@ });
-            model.setInputData(MNIST_IN_NAME, data);
+            const iv = createBufferView(model, MNIST_IN_NAME);
+            const ov = createBufferView(model, MNIST_OUT_NAME);
+            data.forEach((v, i) => {
+                iv.data[i] = v;
+            });
             // Run the model
             return model.run() // 1st run
             .then(() => {
-                validateOutput(model, batchSize);
+                validateOutput(ov, batchSize);
                 return model.run(); // 2nd run
             })
             .then(() => {
-                validateOutput(model, batchSize);
+                validateOutput(ov, batchSize);
                 return model.run(); // 3rd run
             })
             .then(() => {
-                validateOutput(model, batchSize);
+                validateOutput(ov, batchSize);
             });
@@ -197,4 +250,6 @@ });
             })
-            model1.setInputData(MNIST_IN_NAME, data);
+            const iv1 = createBufferView(model1, MNIST_IN_NAME);
+            const ov1 = createBufferView(model1, MNIST_OUT_NAME);
             // create model2
@@ -204,10 +259,15 @@ const model2 = builder.buildModel({
             })
-            model2.setInputData(MNIST_IN_NAME, data);
+            const iv2 = createBufferView(model2, MNIST_IN_NAME);
+            const ov2 = createBufferView(model2, MNIST_OUT_NAME);
+            data.forEach((v, i) => {
+                iv1.data[i] = v;
+                iv2.data[i] = v;
+            });
             // Run these models concurrently
             return Promise.all([ model1.run(), model2.run() ])
             .then(() => {
-                [model1, model2].forEach((model) => {
-                    validateOutput(model, batchSize);
-                });
+                validateOutput(ov1, batchSize);
+                validateOutput(ov2, batchSize);
             });
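The concurrent-models test above relies on the fact that each call to `builder.buildModel()` produces a model with its own input and output buffers, so views created on one model never alias the other. A hedged usage sketch, assuming a `builder` configured as in the tests:

```js
// Two models from one builder can run at the same time (as the test exercises).
const model1 = builder.buildModel({ backendName: 'mkldnn' });
const model2 = builder.buildModel({ backendName: 'mkldnn' });
// Each model owns separate buffers, so their ndarray views are independent.
Promise.all([model1.run(), model2.run()])
.then(() => console.log('both runs completed'));
```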
@@ -431,4 +491,4 @@ });
-    describe('#setInputData tests', function () {
-        it('should throw with invalid input data', function () {
+    describe('#run tests', function () {
+        it('should throw with invalid arg 1', function () {
             return menoh.create(ONNX_FILE_PATH)
@@ -442,11 +502,17 @@ .then((builder) => {
-                model.setInputData('bad_input_name', data); // should throw
+                const iv = createBufferView(model, MNIST_IN_NAME);
+                // Write input data to input view.
+                data.forEach((v, i) => {
+                    iv.data[i] = v;
+                });
+                model.run('bad')
             })
             .then(assert.fail, (err) => {
                 assert.ok(err instanceof Error);
-                assert.ok(err.message.includes('variable not found'));
+                assert.ok(err.message.includes('arg 1'));
             });
         });
-        it('should throw if input data is too short', function () {
+        it('second run() on the same model should fail', function () {
             return menoh.create(ONNX_FILE_PATH)
@@ -460,12 +526,75 @@ .then((builder) => {
-                const tooShort = [0, 1, 2];
-                model.setInputData(MNIST_IN_NAME, tooShort); // should throw
+                const iv = createBufferView(model, MNIST_IN_NAME);
+                const ov = createBufferView(model, MNIST_OUT_NAME);
+                // Write input data to input view.
+                data.forEach((v, i) => {
+                    iv.data[i] = v;
+                });
+                let err1 = null;
+                let err2 = null;
+                return Promise.all([
+                    model.run().catch((err) => {
+                        err1 = err;
+                    }),
+                    model.run().catch((err) => {
+                        err2 = err;
+                    }),
+                ])
+                .then(() => {
+                    assert.ok(!err1);
+                    assert.ok(err2 instanceof Error);
+                    assert.ok(err2.message.includes('in progress'));
+                    validateOutput(ov, batchSize);
+                });
             })
-            .then(assert.fail, (err) => {
-                assert.ok(err instanceof Error);
-                assert.ok(err.message.includes('too short'));
-            });
         });
     });
+    describe('Deprecated feature tests', function () {
+        let imageList;
+        let batchSize;
+        let data;
+        before(function () {
+            return loadInputImages(INPUT_IMAGE_LIST)
+            .then((_imageList) => {
+                imageList = _imageList;
+                batchSize = imageList.length;
+                data = preprocessImages(imageList);
+            });
+        });
+        it('Succeed with setInputData and getOutput', function () {
+            // Load ONNX file
+            return menoh.create(ONNX_FILE_PATH)
+            .then((builder) => {
+                const batchSize = imageList.length;
+                builder.addInput(MNIST_IN_NAME, [ batchSize, 1, 28, 28 ]);
+                builder.addOutput(MNIST_OUT_NAME);
+                // Make a new Model
+                const model = builder.buildModel({
+                    backendName: 'mkldnn'
+                })
+                model.setInputData(MNIST_IN_NAME, data);
+                // Run the model
+                return model.run()
+                .then(() => {
+                    const out = model.getOutput(MNIST_OUT_NAME);
+                    validateOutput(out, batchSize);
+                });
+            });
+        });
-        it('should throw if input data is too long', function () {
+    describe('#setInputData failure tests', function () {
+        it('should throw with invalid input data', function () {
             return menoh.create(ONNX_FILE_PATH)
@@ -479,14 +608,11 @@ .then((builder) => {
-                const tooLong = data.concat([0.666]);
-                model.setInputData(MNIST_IN_NAME, tooLong); // should throw
+                model.setInputData('bad_input_name', data); // should throw
             })
             .then(assert.fail, (err) => {
                 assert.ok(err instanceof Error);
-                assert.ok(err.message.includes('too long'));
+                assert.ok(err.message.includes('variable not found'));
             });
         });
-    });
-    describe('#run tests', function () {
-        it('should throw with invalid arg 1', function () {
+        it('should throw if input data is too short', function () {
             return menoh.create(ONNX_FILE_PATH)
@@ -500,12 +626,12 @@ .then((builder) => {
-                model.setInputData(MNIST_IN_NAME, data);
-                model.run('bad')
+                const tooShort = [0, 1, 2];
+                model.setInputData(MNIST_IN_NAME, tooShort); // should throw
             })
             .then(assert.fail, (err) => {
                 assert.ok(err instanceof Error);
-                assert.ok(err.message.includes('arg 1'));
+                assert.ok(err.message.includes('too short'));
             });
         });
-        it('second run() on the same model should fail', function () {
+        it('should throw if input data is too long', function () {
             return menoh.create(ONNX_FILE_PATH)
@@ -519,27 +645,13 @@ .then((builder) => {
-                model.setInputData(MNIST_IN_NAME, data);
-                let err1 = null;
-                let err2 = null;
-                return Promise.all([
-                    model.run().catch((err) => {
-                        err1 = err;
-                    }),
-                    model.run().catch((err) => {
-                        err2 = err;
-                    }),
-                ])
-                .then(() => {
-                    assert.ok(!err1);
-                    assert.ok(err2 instanceof Error);
-                    assert.ok(err2.message.includes('in progress'));
-                    validateOutput(model, batchSize);
-                });
+                const tooLong = data.concat([0.666]);
+                model.setInputData(MNIST_IN_NAME, tooLong); // should throw
             })
+            .then(assert.fail, (err) => {
+                assert.ok(err instanceof Error);
+                assert.ok(err.message.includes('too long'));
+            });
         });
     });
-    describe('#getOutput tests', function () {
+    describe('#getOutput failure tests', function () {
         it('should throw with invalid output name', function () {
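The reordered '#run tests' above pin down the concurrency contract: a `run()` issued while another is still in flight rejects with an error mentioning 'in progress', while the first run completes and its output stays valid. A hedged usage sketch of that behavior; `model` is assumed to be built as in the tests:

```js
// `model` is assumed to be a built menoh model (see the tests above).
let inFlightError = null;
Promise.all([
    model.run(), // resolves normally
    model.run().catch((err) => { inFlightError = err; }) // rejects: run in progress
])
.then(() => {
    console.log(inFlightError && inFlightError.message); // expected to mention 'in progress'
});
```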