@tensorflow/tfjs-converter
Advanced tools
Comparing version 0.5.7 to 0.5.8
@@ -22,2 +22,4 @@ import { NamedTensorMap, NamedTensorsMap, TensorInfo } from '../data/types'; | ||
private executeWithControlFlow(inputs, context); | ||
private processStack(stack, context, tensorMap, added); | ||
private processChildNodes(node, stack, context, tensorMap, added); | ||
private findOutputs(tensorMap, context, outputs?); | ||
@@ -24,0 +26,0 @@ dispose(): void; |
@@ -199,5 +199,5 @@ "use strict"; | ||
return __awaiter(this, void 0, void 0, function () { | ||
var stack, tensorMap, added, item, nodeName, tensors, _a, _b; | ||
return __generator(this, function (_c) { | ||
switch (_c.label) { | ||
var stack, tensorMap, added, promises; | ||
return __generator(this, function (_a) { | ||
switch (_a.label) { | ||
case 0: | ||
@@ -209,40 +209,9 @@ stack = this.graph.inputs.map(function (node) { | ||
added = {}; | ||
_c.label = 1; | ||
_a.label = 1; | ||
case 1: | ||
if (!(stack.length > 0)) return [3, 3]; | ||
item = stack.pop(); | ||
context.currentContext = item.contexts; | ||
nodeName = ''; | ||
if (item.node.op === 'enter' && | ||
utils_1.getParamValue('isConstant', item.node, tensorMap, context)) { | ||
nodeName = utils_1.getNodeNameAndIndex(item.node.name, context)[0]; | ||
} | ||
tensors = operation_executor_1.executeOp(item.node, tensorMap, context); | ||
if (!nodeName) { | ||
nodeName = utils_1.getNodeNameAndIndex(item.node.name, context)[0]; | ||
} | ||
_a = tensorMap; | ||
_b = nodeName; | ||
return [4, tensors]; | ||
promises = this.processStack(stack, context, tensorMap, added); | ||
return [4, Promise.all(promises)]; | ||
case 2: | ||
_a[_b] = _c.sent(); | ||
item.node.children.forEach(function (childNode) { | ||
var nodeName = utils_1.getNodeNameAndIndex(childNode.name, context)[0]; | ||
if (!added[nodeName]) { | ||
if (childNode.op === 'merge') { | ||
if (childNode.inputNames.some(function (name) { | ||
return !!utils_1.getTensor(name, tensorMap, context); | ||
})) { | ||
added[nodeName] = true; | ||
stack.push({ contexts: context.currentContext, node: childNode }); | ||
} | ||
} | ||
else if (childNode.inputNames.every(function (name) { | ||
return !!utils_1.getTensor(name, tensorMap, context); | ||
})) { | ||
added[nodeName] = true; | ||
stack.push({ contexts: context.currentContext, node: childNode }); | ||
} | ||
} | ||
}); | ||
_a.sent(); | ||
return [3, 1]; | ||
@@ -254,2 +223,58 @@ case 3: return [2, tensorMap]; | ||
}; | ||
GraphExecutor.prototype.processStack = function (stack, context, tensorMap, added) { | ||
var _this = this; | ||
var promises = []; | ||
var _loop_1 = function () { | ||
var item = stack.pop(); | ||
context.currentContext = item.contexts; | ||
var nodeName = ''; | ||
if (item.node.op === 'enter' && | ||
utils_1.getParamValue('isConstant', item.node, tensorMap, context)) { | ||
nodeName = utils_1.getNodeNameAndIndex(item.node.name, context)[0]; | ||
} | ||
var tensors = operation_executor_1.executeOp(item.node, tensorMap, context); | ||
if (!nodeName) { | ||
nodeName = utils_1.getNodeNameAndIndex(item.node.name, context)[0]; | ||
} | ||
var currentContext = context.currentContext; | ||
if (tensors instanceof Promise) { | ||
promises.push(tensors.then(function (t) { | ||
tensorMap[nodeName] = t; | ||
context.currentContext = currentContext; | ||
_this.processChildNodes(item.node, stack, context, tensorMap, added); | ||
return t; | ||
})); | ||
} | ||
else { | ||
tensorMap[nodeName] = tensors; | ||
this_1.processChildNodes(item.node, stack, context, tensorMap, added); | ||
} | ||
}; | ||
var this_1 = this; | ||
while (stack.length > 0) { | ||
_loop_1(); | ||
} | ||
return promises; | ||
}; | ||
GraphExecutor.prototype.processChildNodes = function (node, stack, context, tensorMap, added) { | ||
node.children.forEach(function (childNode) { | ||
var nodeName = utils_1.getNodeNameAndIndex(childNode.name, context)[0]; | ||
if (!added[nodeName]) { | ||
if (childNode.op === 'merge') { | ||
if (childNode.inputNames.some(function (name) { | ||
return !!utils_1.getTensor(name, tensorMap, context); | ||
})) { | ||
added[nodeName] = true; | ||
stack.push({ contexts: context.currentContext, node: childNode }); | ||
} | ||
} | ||
else if (childNode.inputNames.every(function (name) { | ||
return !!utils_1.getTensor(name, tensorMap, context); | ||
})) { | ||
added[nodeName] = true; | ||
stack.push({ contexts: context.currentContext, node: childNode }); | ||
} | ||
} | ||
}); | ||
}; | ||
GraphExecutor.prototype.findOutputs = function (tensorMap, context, outputs) { | ||
@@ -256,0 +281,0 @@ if (outputs && !(outputs instanceof Array)) { |
@@ -32,2 +32,4 @@ import { DataType, Tensor } from '@tensorflow/tfjs-core'; | ||
split(length: number[], tensor: Tensor): void; | ||
private assertShapesMatch(shapeA, shapeB, errorMessagePrefix?); | ||
private arraysEqual(n1, n2); | ||
} |
@@ -68,3 +68,3 @@ "use strict"; | ||
} | ||
tfjs_core_1.util.assertShapesMatch(this.elementShape, tensor.shape, "TensorArray " + this.name + ": Could not write to TensorArray index " + index + "."); | ||
this.assertShapesMatch(this.elementShape, tensor.shape, "TensorArray " + this.name + ": Could not write to TensorArray index " + index + "."); | ||
if (t && t.read) { | ||
@@ -102,3 +102,3 @@ throw new Error("TensorArray " + this.name + ": Could not write to TensorArray index " + index + ", because it has already been read."); | ||
var tensors = this.readMany(indices); | ||
tfjs_core_1.util.assertShapesMatch(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: '); | ||
this.assertShapesMatch(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: '); | ||
return tfjs_core_1.stack(tensors, 0); | ||
@@ -118,3 +118,3 @@ }; | ||
var tensors = this.readMany(indices); | ||
tfjs_core_1.util.assertShapesMatch(this.elementShape, tensors[0].shape, "TensorArray shape mismatch: tensor array shape (" + this.elementShape + ") vs first tensor shape (" + tensors[0].shape + ")"); | ||
this.assertShapesMatch(this.elementShape, tensors[0].shape, "TensorArray shape mismatch: tensor array shape (" + this.elementShape + ") vs first tensor shape (" + tensors[0].shape + ")"); | ||
return tfjs_core_1.concat(tensors, 0); | ||
@@ -170,2 +170,17 @@ }; | ||
}; | ||
TensorArray.prototype.assertShapesMatch = function (shapeA, shapeB, errorMessagePrefix) { | ||
if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; } | ||
tfjs_core_1.util.assert(this.arraysEqual(shapeA, shapeB), errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match")); | ||
}; | ||
TensorArray.prototype.arraysEqual = function (n1, n2) { | ||
if (n1.length !== n2.length) { | ||
return false; | ||
} | ||
for (var i = 0; i < n1.length; i++) { | ||
if (n1[i] !== -1 && n2[i] !== -1 && n1[i] !== n2[i]) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
}; | ||
TensorArray.nextId = 0; | ||
@@ -172,0 +187,0 @@ return TensorArray; |
@@ -1,2 +0,2 @@ | ||
export declare const json: ({ | ||
export declare const json: { | ||
'tfOpName': string; | ||
@@ -27,25 +27,2 @@ 'dlOpName': string; | ||
})[]; | ||
} | { | ||
'tfOpName': string; | ||
'dlOpName': string; | ||
'category': string; | ||
'params': ({ | ||
'tfInputIndex': number; | ||
'dlParamName': string; | ||
'type': string; | ||
tfParamName?: undefined; | ||
notSupported?: undefined; | ||
} | { | ||
'tfParamName': string; | ||
'dlParamName': string; | ||
'type': string; | ||
tfInputIndex?: undefined; | ||
notSupported?: undefined; | ||
} | { | ||
'tfParamName': string; | ||
'dlParamName': string; | ||
'type': string; | ||
'notSupported': boolean; | ||
tfInputIndex?: undefined; | ||
})[]; | ||
})[]; | ||
}[]; |
@@ -36,3 +36,3 @@ "use strict"; | ||
{ 'tfInputIndex': 0, 'dlParamName': 'x', 'type': 'tensor' }, | ||
{ 'tfParamName': 'perm', 'dlParamName': 'perm', 'type': 'number[]' }, { | ||
{ 'tfInputIndex': 1, 'dlParamName': 'perm', 'type': 'number[]' }, { | ||
'tfParamName': 'T', | ||
@@ -39,0 +39,0 @@ 'dlParamName': 'dtype', |
@@ -1,2 +0,2 @@ | ||
declare const version = "0.5.7"; | ||
declare const version = "0.5.8"; | ||
export { version }; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var version = '0.5.7'; | ||
var version = '0.5.8'; | ||
exports.version = version; | ||
//# sourceMappingURL=version.js.map |
{ | ||
"name": "@tensorflow/tfjs-converter", | ||
"version": "0.5.7", | ||
"version": "0.5.8", | ||
"description": "Tensorflow model converter for javascript", | ||
@@ -5,0 +5,0 @@ "main": "dist/src/index.js", |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
2669051
16755