@tensorflow/tfjs-converter
Advanced tools
Comparing version 0.5.9 to 0.6.0
@@ -22,2 +22,3 @@ import * as tfc from '@tensorflow/tfjs-core'; | ||
execute(inputs: tfc.Tensor | tfc.Tensor[] | tfc.NamedTensorMap, outputs?: string | string[]): tfc.Tensor | tfc.Tensor[]; | ||
private execute_(inputs, strictInputCheck?, outputs?); | ||
executeAsync(inputs: tfc.Tensor | tfc.Tensor[] | tfc.NamedTensorMap, outputs?: string | string[]): Promise<tfc.Tensor | tfc.Tensor[]>; | ||
@@ -24,0 +25,0 @@ private convertTensorMapToTensorsMap(map); |
@@ -134,3 +134,3 @@ "use strict"; | ||
FrozenModel.prototype.predict = function (inputs, config) { | ||
return this.execute(inputs, this.outputNodes); | ||
return this.execute_(inputs, true, this.outputNodes); | ||
}; | ||
@@ -150,2 +150,6 @@ FrozenModel.prototype.constructTensorMap = function (inputs) { | ||
FrozenModel.prototype.execute = function (inputs, outputs) { | ||
return this.execute_(inputs, false, outputs); | ||
}; | ||
FrozenModel.prototype.execute_ = function (inputs, strictInputCheck, outputs) { | ||
if (strictInputCheck === void 0) { strictInputCheck = true; } | ||
outputs = outputs || this.outputNodes; | ||
@@ -159,3 +163,3 @@ if (inputs instanceof tfc.Tensor || Array.isArray(inputs)) { | ||
} | ||
var result = this.executor.execute(this.convertTensorMapToTensorsMap(inputs), outputs); | ||
var result = this.executor.execute(this.convertTensorMapToTensorsMap(inputs), strictInputCheck, outputs); | ||
var keys = Object.keys(result); | ||
@@ -162,0 +166,0 @@ return (Array.isArray(outputs) && outputs.length > 1) ? |
@@ -5,3 +5,3 @@ import { NamedTensorMap, NamedTensorsMap, TensorInfo } from '../data/types'; | ||
private graph; | ||
private compiledOrder; | ||
private compiledMap; | ||
private _weightMap; | ||
@@ -11,2 +11,3 @@ private weightIds; | ||
private _outputs; | ||
private SEPERATOR; | ||
weightMap: NamedTensorsMap; | ||
@@ -20,12 +21,14 @@ readonly inputs: TensorInfo[]; | ||
readonly isDynamicShapeModel: boolean; | ||
private compile(); | ||
execute(inputs: NamedTensorsMap, outputs?: string | string[]): NamedTensorMap; | ||
private compile(startNodes?); | ||
execute(inputs: NamedTensorsMap, strictInputCheck?: boolean, outputs?: string | string[]): NamedTensorMap; | ||
executeAsync(inputs: NamedTensorsMap, outputs?: string | string[]): Promise<NamedTensorMap>; | ||
private executeWithControlFlow(inputs, context); | ||
private processStack(stack, context, tensorMap, added); | ||
private processStack(inputNodes, stack, context, tensorMap, added); | ||
private processChildNodes(node, stack, context, tensorMap, added); | ||
private calculateOutputs(outputs?); | ||
private findOutputs(tensorMap, context, outputs?); | ||
dispose(): void; | ||
private checkInputShapeAndType(inputs); | ||
private checkInput(inputs); | ||
private checkInputShapeAndType(inputs, strictInputCheck?); | ||
private checkInput(inputs, strictInputCheck?); | ||
private checkOutput(compiledNodes, outputs); | ||
} |
@@ -53,4 +53,5 @@ "use strict"; | ||
this.graph = graph; | ||
this.compiledOrder = []; | ||
this.compiledMap = new Map(); | ||
this._weightMap = {}; | ||
this.SEPERATOR = ','; | ||
this.placeholders = graph.placeholders; | ||
@@ -130,7 +131,14 @@ this._outputs = graph.outputs; | ||
}); | ||
GraphExecutor.prototype.compile = function () { | ||
GraphExecutor.prototype.compile = function (startNodes) { | ||
if (this.graph.withControlFlow || this.graph.withDynamicShape) { | ||
return; | ||
} | ||
var stack = this.graph.inputs.slice(); | ||
var compiledOrder = []; | ||
var inputs = startNodes || this.graph.placeholders; | ||
var sortedNodeNames = inputs.map(function (node) { return node.name; }).sort(); | ||
var nameKey = sortedNodeNames.join(this.SEPERATOR); | ||
if (this.compiledMap.get(nameKey)) { | ||
return; | ||
} | ||
var stack = inputs.concat(this.graph.weights); | ||
var visited = {}; | ||
@@ -140,3 +148,3 @@ while (stack.length > 0) { | ||
visited[node.name] = true; | ||
this.compiledOrder.push(node); | ||
compiledOrder.push(node); | ||
node.children.forEach(function (childNode) { | ||
@@ -151,15 +159,29 @@ if (!visited[childNode.name] && childNode.inputNames.every(function (name) { | ||
} | ||
this.compiledMap.set(nameKey, compiledOrder); | ||
}; | ||
GraphExecutor.prototype.execute = function (inputs, outputs) { | ||
GraphExecutor.prototype.execute = function (inputs, strictInputCheck, outputs) { | ||
var _this = this; | ||
this.checkInput(inputs); | ||
this.checkInputShapeAndType(inputs); | ||
if (strictInputCheck === void 0) { strictInputCheck = true; } | ||
var names = Object.keys(inputs).sort(); | ||
this.checkInput(inputs, strictInputCheck); | ||
this.checkInputShapeAndType(inputs, strictInputCheck); | ||
this.compile(names.map(function (name) { return _this.graph.nodes[name]; })); | ||
var outputNames = this.calculateOutputs(outputs); | ||
this.checkOutput(this.compiledMap.get(names.join(this.SEPERATOR)), outputNames); | ||
var tensorArrayMap = {}; | ||
var result = tfjs_core_1.tidy(function () { | ||
var context = new execution_context_1.ExecutionContext(_this._weightMap, tensorArrayMap); | ||
var tensors = _this.compiledOrder.reduce(function (map, node) { | ||
map[node.name] = operation_executor_1.executeOp(node, map, context); | ||
return map; | ||
}, __assign({}, _this.weightMap, inputs)); | ||
return _this.findOutputs(tensors, context, outputs); | ||
var tensorMap = __assign({}, _this.weightMap, inputs); | ||
var compiledNodes = _this.compiledMap.get(names.join(_this.SEPERATOR)); | ||
for (var i = 0; i < compiledNodes.length; i++) { | ||
var node = compiledNodes[i]; | ||
if (!tensorMap[node.name]) { | ||
tensorMap[node.name] = | ||
operation_executor_1.executeOp(node, tensorMap, context); | ||
} | ||
if (outputNames.every(function (name) { return !!tensorMap[name]; })) { | ||
break; | ||
} | ||
} | ||
return _this.findOutputs(tensorMap, context, outputNames); | ||
}); | ||
@@ -175,4 +197,4 @@ return result; | ||
case 0: | ||
this.checkInput(inputs); | ||
this.checkInputShapeAndType(inputs); | ||
this.checkInput(inputs, false); | ||
this.checkInputShapeAndType(inputs, false); | ||
tensorArrayMap = {}; | ||
@@ -204,7 +226,10 @@ context = new execution_context_1.ExecutionContext(this._weightMap, tensorArrayMap); | ||
return __awaiter(this, void 0, void 0, function () { | ||
var stack, tensorMap, added, promises; | ||
var _this = this; | ||
var names, inputNodes, stack, tensorMap, added, promises; | ||
return __generator(this, function (_a) { | ||
switch (_a.label) { | ||
case 0: | ||
stack = this.graph.inputs.map(function (node) { | ||
names = Object.keys(inputs); | ||
inputNodes = names.map(function (name) { return _this.graph.nodes[name]; }); | ||
stack = inputNodes.concat(this.graph.weights).map(function (node) { | ||
return { node: node, contexts: context.currentContext }; | ||
@@ -217,3 +242,3 @@ }); | ||
if (!(stack.length > 0)) return [3, 3]; | ||
promises = this.processStack(stack, context, tensorMap, added); | ||
promises = this.processStack(inputNodes, stack, context, tensorMap, added); | ||
return [4, Promise.all(promises)]; | ||
@@ -228,3 +253,3 @@ case 2: | ||
}; | ||
GraphExecutor.prototype.processStack = function (stack, context, tensorMap, added) { | ||
GraphExecutor.prototype.processStack = function (inputNodes, stack, context, tensorMap, added) { | ||
var _this = this; | ||
@@ -240,17 +265,22 @@ var promises = []; | ||
} | ||
var tensors = operation_executor_1.executeOp(item.node, tensorMap, context); | ||
if (!nodeName) { | ||
nodeName = utils_1.getNodeNameAndIndex(item.node.name, context)[0]; | ||
if (inputNodes.indexOf(item.node) === -1) { | ||
var tensors = operation_executor_1.executeOp(item.node, tensorMap, context); | ||
if (!nodeName) { | ||
nodeName = utils_1.getNodeNameAndIndex(item.node.name, context)[0]; | ||
} | ||
var currentContext_1 = context.currentContext; | ||
if (tensors instanceof Promise) { | ||
promises.push(tensors.then(function (t) { | ||
tensorMap[nodeName] = t; | ||
context.currentContext = currentContext_1; | ||
_this.processChildNodes(item.node, stack, context, tensorMap, added); | ||
return t; | ||
})); | ||
} | ||
else { | ||
tensorMap[nodeName] = tensors; | ||
this_1.processChildNodes(item.node, stack, context, tensorMap, added); | ||
} | ||
} | ||
var currentContext = context.currentContext; | ||
if (tensors instanceof Promise) { | ||
promises.push(tensors.then(function (t) { | ||
tensorMap[nodeName] = t; | ||
context.currentContext = currentContext; | ||
_this.processChildNodes(item.node, stack, context, tensorMap, added); | ||
return t; | ||
})); | ||
} | ||
else { | ||
tensorMap[nodeName] = tensors; | ||
this_1.processChildNodes(item.node, stack, context, tensorMap, added); | ||
@@ -286,7 +316,10 @@ } | ||
}; | ||
GraphExecutor.prototype.findOutputs = function (tensorMap, context, outputs) { | ||
GraphExecutor.prototype.calculateOutputs = function (outputs) { | ||
if (outputs && !(outputs instanceof Array)) { | ||
outputs = [outputs]; | ||
} | ||
var requestedOutputs = (outputs || this.graph.outputs.map(function (node) { return node.name; })); | ||
return (outputs || this.graph.outputs.map(function (node) { return node.name; })); | ||
}; | ||
GraphExecutor.prototype.findOutputs = function (tensorMap, context, outputs) { | ||
var requestedOutputs = this.calculateOutputs(outputs); | ||
return requestedOutputs.reduce(function (map, name) { | ||
@@ -302,5 +335,10 @@ map[name] = utils_1.getTensor(name, tensorMap, context); | ||
}; | ||
GraphExecutor.prototype.checkInputShapeAndType = function (inputs) { | ||
GraphExecutor.prototype.checkInputShapeAndType = function (inputs, strictInputCheck) { | ||
if (strictInputCheck === void 0) { strictInputCheck = true; } | ||
this.placeholders.forEach(function (node) { | ||
var input = inputs[node.name][0]; | ||
var inputTensors = inputs[node.name]; | ||
if (!strictInputCheck && !inputTensors) { | ||
return; | ||
} | ||
var input = inputTensors[0]; | ||
if (node.params['shape'] && node.params['shape'].value) { | ||
@@ -317,4 +355,5 @@ var shape_1 = node.params['shape'].value; | ||
}; | ||
GraphExecutor.prototype.checkInput = function (inputs) { | ||
GraphExecutor.prototype.checkInput = function (inputs, strictInputCheck) { | ||
var _this = this; | ||
if (strictInputCheck === void 0) { strictInputCheck = true; } | ||
var inputKeys = Object.keys(inputs); | ||
@@ -331,7 +370,8 @@ var missing = []; | ||
}); | ||
if (missing.length > 0) { | ||
var notInGraph = extra.filter(function (name) { return !_this.graph.nodes[name]; }); | ||
if (missing.length > 0 && strictInputCheck) { | ||
throw new Error("The dict provided in model.execute(dict) has the keys " + | ||
("[" + inputKeys + "], but is missing the required keys: [" + missing + "].")); | ||
} | ||
if (extra.length > 0) { | ||
if (extra.length > 0 && strictInputCheck) { | ||
throw new Error("The dict provided in model.execute(dict) has " + | ||
@@ -341,3 +381,19 @@ ("unused keys: [" + extra + "]. Please provide only the following keys: ") + | ||
} | ||
if (notInGraph.length > 0) { | ||
throw new Error("The dict provided in model.execute(dict) has " + | ||
("keys: [" + notInGraph + "] not part of model graph.")); | ||
} | ||
}; | ||
GraphExecutor.prototype.checkOutput = function (compiledNodes, outputs) { | ||
var compiledNodeNames = compiledNodes.map(function (node) { return node.name; }); | ||
var extra = []; | ||
outputs.forEach(function (name) { | ||
if (compiledNodeNames.indexOf(name) === -1) | ||
extra.push(name); | ||
}); | ||
if (extra.length > 0) { | ||
throw new Error("The following outputs are not be generated by the execution: " + | ||
("[" + extra + "].")); | ||
} | ||
}; | ||
return GraphExecutor; | ||
@@ -344,0 +400,0 @@ }()); |
@@ -22,2 +22,5 @@ "use strict"; | ||
return [tfc.tensor1d(utils_1.getParamValue('x', node, tensorMap, context).shape, 'int32')]; | ||
case 'shapeN': | ||
return utils_1.getParamValue('x', node, tensorMap, context) | ||
.map(function (t) { return tfc.tensor1d(t.shape); }); | ||
case 'size': | ||
@@ -24,0 +27,0 @@ return [tfc.scalar(utils_1.getParamValue('x', node, tensorMap, context).size, 'int32')]; |
@@ -54,2 +54,13 @@ "use strict"; | ||
{ | ||
'tfOpName': 'ShapeN', | ||
'dlOpName': 'shapeN', | ||
'category': 'graph', | ||
'params': [{ | ||
'tfInputIndex': 0, | ||
'tfInputParamLength': 0, | ||
'dlParamName': 'x', | ||
'type': 'tensors' | ||
}] | ||
}, | ||
{ | ||
'tfOpName': 'Print', | ||
@@ -56,0 +67,0 @@ 'dlOpName': 'print', |
@@ -54,2 +54,3 @@ "use strict"; | ||
var placeholders = []; | ||
var weights = []; | ||
var nodes = tfNodes.reduce(function (map, node) { | ||
@@ -63,2 +64,4 @@ map[node.name] = _this.mapNode(node); | ||
placeholders.push(map[node.name]); | ||
if (node.op === 'Const') | ||
weights.push(map[node.name]); | ||
return map; | ||
@@ -87,2 +90,3 @@ }, {}); | ||
outputs: outputs, | ||
weights: weights, | ||
placeholders: placeholders, | ||
@@ -89,0 +93,0 @@ withControlFlow: withControlFlow, |
@@ -39,2 +39,3 @@ import { Tensor } from '@tensorflow/tfjs-core'; | ||
outputs: Node[]; | ||
weights: Node[]; | ||
withControlFlow: boolean; | ||
@@ -41,0 +42,0 @@ withDynamicShape: boolean; |
@@ -1,2 +0,2 @@ | ||
declare const version = "0.5.9"; | ||
declare const version = "0.6.0"; | ||
export { version }; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var version = '0.5.9'; | ||
var version = '0.6.0'; | ||
exports.version = version; | ||
//# sourceMappingURL=version.js.map |
{ | ||
"name": "@tensorflow/tfjs-converter", | ||
"version": "0.5.9", | ||
"version": "0.6.0", | ||
"description": "Tensorflow model converter for javascript", | ||
@@ -17,6 +17,6 @@ "main": "dist/src/index.js", | ||
"peerDependencies": { | ||
"@tensorflow/tfjs-core": "~0.12.15" | ||
"@tensorflow/tfjs-core": "~0.13.0" | ||
}, | ||
"devDependencies": { | ||
"@tensorflow/tfjs-core": "~0.12.15", | ||
"@tensorflow/tfjs-core": "~0.13.0", | ||
"@types/jasmine": "~2.8.6", | ||
@@ -23,0 +23,0 @@ "@types/node-fetch": "1.6.9", |
@@ -110,3 +110,3 @@ [![Build Status](https://travis-ci.org/tensorflow/tfjs-converter.svg?branch=master)](https://travis-ci.org/tensorflow/tfjs-converter) | ||
* `web_model.pb` (the dataflow graph) | ||
* `tensorflowjs_model.pb` (the dataflow graph) | ||
* `weights_manifest.json` (weight manifest file) | ||
@@ -138,3 +138,3 @@ * `group1-shard\*of\*` (collection of binary weight files) | ||
const MODEL_URL = 'https://.../mobilenet/web_model.pb'; | ||
const MODEL_URL = 'https://.../mobilenet/tensorflowjs_model.pb'; | ||
const WEIGHTS_URL = 'https://.../mobilenet/weights_manifest.json'; | ||
@@ -171,3 +171,3 @@ | ||
const MODEL_PATH = 'file:///tmp/mobilenet/web_model.pb'; | ||
const MODEL_PATH = 'file:///tmp/mobilenet/tensorflowjs_model.pb'; | ||
const WEIGHTS_PATH = 'file:///tmp/mobilenet/weights_manifest.json'; | ||
@@ -174,0 +174,0 @@ const model = await tf.loadFrozenModel(MODEL_PATH, WEIGHTS_PATH); |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
2707873
16928