@tensorflow/tfjs-converter
Advanced tools
Comparing version 2.0.0 to 2.0.1
@@ -20,2 +20,3 @@ /** | ||
import { TensorArray } from './tensor_array'; | ||
import { FunctionExecutor } from './types'; | ||
export interface ExecutionContextInfo { | ||
@@ -38,2 +39,5 @@ id: number; | ||
readonly tensorArrayMap: TensorArrayMap; | ||
readonly functionMap: { | ||
[key: string]: FunctionExecutor; | ||
}; | ||
private rootContext; | ||
@@ -43,3 +47,5 @@ private contexts; | ||
private _currentContextIds; | ||
constructor(weightMap: NamedTensorsMap, tensorArrayMap: TensorArrayMap); | ||
constructor(weightMap: NamedTensorsMap, tensorArrayMap: TensorArrayMap, functionMap?: { | ||
[key: string]: FunctionExecutor; | ||
}); | ||
private newFrame; | ||
@@ -46,0 +52,0 @@ /** |
@@ -11,5 +11,6 @@ /** | ||
export class ExecutionContext { | ||
constructor(weightMap, tensorArrayMap) { | ||
constructor(weightMap, tensorArrayMap, functionMap = {}) { | ||
this.weightMap = weightMap; | ||
this.tensorArrayMap = tensorArrayMap; | ||
this.functionMap = functionMap; | ||
this.rootContext = { id: 0, frameName: '', iterationId: 0 }; | ||
@@ -16,0 +17,0 @@ this.contexts = [this.rootContext]; |
@@ -18,9 +18,12 @@ /** | ||
import { NamedTensorMap, Tensor } from '@tensorflow/tfjs-core'; | ||
import { ISignatureDef } from '../data/compiled_api'; | ||
import { NamedTensorsMap, TensorInfo } from '../data/types'; | ||
import { Graph } from '../operations/types'; | ||
export declare class GraphExecutor { | ||
import { FunctionExecutor } from './types'; | ||
export declare class GraphExecutor implements FunctionExecutor { | ||
private graph; | ||
private parent?; | ||
private compiledMap; | ||
private _weightMap; | ||
private weightIds; | ||
private _weightIds; | ||
private _signature; | ||
@@ -30,2 +33,8 @@ private _inputs; | ||
private SEPERATOR; | ||
private _functions; | ||
private _functionExecutorMap; | ||
readonly weightIds: number[]; | ||
readonly functionExecutorMap: { | ||
[key: string]: FunctionExecutor; | ||
}; | ||
weightMap: NamedTensorsMap; | ||
@@ -36,3 +45,14 @@ readonly inputs: TensorInfo[]; | ||
readonly outputNodes: string[]; | ||
constructor(graph: Graph); | ||
readonly functions: { | ||
[key: string]: ISignatureDef; | ||
}; | ||
/** | ||
* | ||
* @param graph Graph the model or function graph to be executed. | ||
* @param parent When building function exector you need to set the parent | ||
* executor. Since the weights and function executor maps are set at parant | ||
* level, that function executor can access the function maps and weight maps | ||
* through the parent. | ||
*/ | ||
constructor(graph: Graph, parent?: GraphExecutor); | ||
private getCompilationKey; | ||
@@ -64,4 +84,7 @@ /** | ||
* array. | ||
* @param disableWarning disable the no dynamic ops warning message, default | ||
* to false | ||
*/ | ||
executeAsync(inputs: NamedTensorMap, outputs: string[]): Promise<Tensor[]>; | ||
executeAsync(inputs: NamedTensorMap, outputs: string[], disableWarning?: boolean): Promise<Tensor[]>; | ||
executeFunctionAsync(inputs: Tensor[]): Promise<Tensor[]>; | ||
/** | ||
@@ -72,2 +95,3 @@ * When there are control flow nodes in the graph, the graph execution use | ||
* @param context the execution context object for current execution. | ||
* @param disableWarning disable no async op warning | ||
*/ | ||
@@ -74,0 +98,0 @@ private executeWithControlFlow; |
@@ -23,17 +23,43 @@ /** | ||
export class GraphExecutor { | ||
constructor(graph) { | ||
/** | ||
* | ||
* @param graph Graph the model or function graph to be executed. | ||
* @param parent When building function exector you need to set the parent | ||
* executor. Since the weights and function executor maps are set at parant | ||
* level, that function executor can access the function maps and weight maps | ||
* through the parent. | ||
*/ | ||
constructor(graph, parent) { | ||
this.graph = graph; | ||
this.parent = parent; | ||
this.compiledMap = new Map(); | ||
this._weightMap = {}; | ||
this.SEPERATOR = ','; | ||
this._functions = {}; | ||
this._functionExecutorMap = {}; | ||
this._outputs = graph.outputs; | ||
this._inputs = graph.inputs; | ||
this._signature = graph.signature; | ||
this._functions = graph.functions; | ||
// create sub-graph executors | ||
if (graph.functions != null) { | ||
Object.keys(graph.functions).forEach(name => { | ||
this._functionExecutorMap[name] = | ||
new GraphExecutor(graph.functions[name], this); | ||
}); | ||
} | ||
} | ||
get weightIds() { | ||
return this.parent ? this.parent.weightIds : this._weightIds; | ||
} | ||
get functionExecutorMap() { | ||
return this.parent ? this.parent.functionExecutorMap : | ||
this._functionExecutorMap; | ||
} | ||
get weightMap() { | ||
return this._weightMap; | ||
return this.parent ? this.parent.weightMap : this._weightMap; | ||
} | ||
set weightMap(weightMap) { | ||
const weightIds = Object.keys(weightMap).map(key => weightMap[key].map(tensor => tensor.id)); | ||
this.weightIds = [].concat(...weightIds); | ||
this._weightIds = [].concat(...weightIds); | ||
this._weightMap = weightMap; | ||
@@ -71,4 +97,13 @@ } | ||
get outputNodes() { | ||
return this._outputs.map(node => node.signatureKey || node.name); | ||
return this._outputs.map((node) => { | ||
const name = node.signatureKey || node.name; | ||
return node.defaultOutput ? (`${name}:${node.defaultOutput}`) : name; | ||
}); | ||
} | ||
get functions() { | ||
return Object.keys(this._functions).reduce((map, key) => { | ||
map[key] = this._functions[key].signature; | ||
return map; | ||
}, {}); | ||
} | ||
getCompilationKey(inputs, outputs) { | ||
@@ -128,3 +163,3 @@ const sortedInputs = inputs.map(node => node.name).sort(); | ||
return tidy(() => { | ||
const context = new ExecutionContext(this._weightMap, tensorArrayMap); | ||
const context = new ExecutionContext(this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
const tensorsMap = Object.assign({}, this.weightMap); | ||
@@ -205,4 +240,6 @@ Object.keys(inputs).forEach(name => { | ||
* array. | ||
* @param disableWarning disable the no dynamic ops warning message, default | ||
* to false | ||
*/ | ||
async executeAsync(inputs, outputs) { | ||
async executeAsync(inputs, outputs, disableWarning = false) { | ||
inputs = this.mapInputs(inputs); | ||
@@ -214,7 +251,7 @@ this.checkInputs(inputs); | ||
const tensorArrayMap = {}; | ||
const context = new ExecutionContext(this._weightMap, tensorArrayMap); | ||
const context = new ExecutionContext(this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
// Graph with control flow op requires runtime evaluation of the execution | ||
// order, while without control flow the execution order is pre-determined | ||
// in the compile method. | ||
const tensorMap = await this.executeWithControlFlow(inputs, context, outputs); | ||
const tensorMap = await this.executeWithControlFlow(inputs, context, outputs, disableWarning); | ||
const results = outputs.map(name => getTensor(name, tensorMap, context)); | ||
@@ -236,2 +273,9 @@ // dispose all the intermediate tensors | ||
} | ||
async executeFunctionAsync(inputs) { | ||
const mappedInputs = inputs.reduce((map, tensor, index) => { | ||
map[this.inputs[index].name] = tensor; | ||
return map; | ||
}, {}); | ||
return this.executeAsync(mappedInputs, this.outputNodes, true); | ||
} | ||
/** | ||
@@ -242,4 +286,5 @@ * When there are control flow nodes in the graph, the graph execution use | ||
* @param context the execution context object for current execution. | ||
* @param disableWarning disable no async op warning | ||
*/ | ||
async executeWithControlFlow(inputs, context, outputNames) { | ||
async executeWithControlFlow(inputs, context, outputNames, disableWarning) { | ||
const names = Object.keys(inputs); | ||
@@ -266,3 +311,3 @@ const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]); | ||
} | ||
if (dynamicNode == null) { | ||
if (dynamicNode == null && !disableWarning) { | ||
console.warn(`This model execution did not contain any nodes with control flow ` + | ||
@@ -269,0 +314,0 @@ `or dynamic output shapes. You can use model.execute() instead.`); |
@@ -22,4 +22,4 @@ /** | ||
/** | ||
* A `tf.GraphModel` is a directed, acyclic graph of built from | ||
* SavedModel GraphDef and allows inference exeuction. | ||
* A `tf.GraphModel` is a directed, acyclic graph built from a | ||
* SavedModel GraphDef and allows inference execution. | ||
* | ||
@@ -61,2 +61,8 @@ * A `tf.GraphModel` can only be created by loading from a model converted from | ||
/** | ||
* Synchronously construct the in memory weight map and | ||
* compile the inference graph. | ||
*/ | ||
/** @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ | ||
loadSync(artifacts: io.ModelArtifacts): boolean; | ||
/** | ||
* Save the configuration and/or weights of the GraphModel. | ||
@@ -63,0 +69,0 @@ * |
@@ -23,4 +23,4 @@ /** | ||
/** | ||
* A `tf.GraphModel` is a directed, acyclic graph of built from | ||
* SavedModel GraphDef and allows inference exeuction. | ||
* A `tf.GraphModel` is a directed, acyclic graph built from a | ||
* SavedModel GraphDef and allows inference execution. | ||
* | ||
@@ -102,3 +102,12 @@ * A `tf.GraphModel` can only be created by loading from a model converted from | ||
} | ||
this.artifacts = await this.handler.load(); | ||
const artifacts = await this.handler.load(); | ||
return this.loadSync(artifacts); | ||
} | ||
/** | ||
* Synchronously construct the in memory weight map and | ||
* compile the inference graph. | ||
*/ | ||
/** @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ | ||
loadSync(artifacts) { | ||
this.artifacts = artifacts; | ||
const graph = this.artifacts.modelTopology; | ||
@@ -105,0 +114,0 @@ let signature = {}; |
@@ -106,3 +106,6 @@ /** | ||
} | ||
const CONTROL_FLOW_OPS = ['Switch', 'Merge', 'Enter', 'Exit', 'NextIteration']; | ||
const CONTROL_FLOW_OPS = [ | ||
'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf', | ||
'StatelessWhile' | ||
]; | ||
const DYNAMIC_SHAPE_OPS = [ | ||
@@ -109,0 +112,0 @@ 'NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where' |
@@ -93,9 +93,2 @@ /** | ||
split(length: number[], tensor: Tensor): void; | ||
/** | ||
* This differs from util.assertShapesMatch in that it allows values of | ||
* negative one, an undefined size of a dimensinon, in a shape to match | ||
* anything. | ||
*/ | ||
private assertShapesMatchAllowUndefinedSize; | ||
private shapesEqualAllowUndefinedSize; | ||
} |
@@ -17,3 +17,4 @@ /** | ||
*/ | ||
import { concat, slice, stack, tensor, tidy, unstack, util } from '@tensorflow/tfjs-core'; | ||
import { concat, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core'; | ||
import { assertShapesMatchAllowUndefinedSize } from './tensor_utils'; | ||
/** | ||
@@ -100,3 +101,3 @@ * The TensorArray object keeps an array of Tensors. It | ||
} | ||
this.assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`); | ||
assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`); | ||
if (t && t.read) { | ||
@@ -149,3 +150,3 @@ throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`); | ||
const tensors = this.readMany(indices); | ||
this.assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: '); | ||
assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: '); | ||
return stack(tensors, 0); | ||
@@ -169,3 +170,3 @@ } | ||
const tensors = this.readMany(indices); | ||
this.assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, `TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`); | ||
assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, `TensorArray shape mismatch: tensor array shape (${this.elementShape}) vs first tensor shape (${tensors[0].shape})`); | ||
return concat(tensors, 0); | ||
@@ -234,23 +235,4 @@ } | ||
} | ||
/** | ||
* This differs from util.assertShapesMatch in that it allows values of | ||
* negative one, an undefined size of a dimensinon, in a shape to match | ||
* anything. | ||
*/ | ||
assertShapesMatchAllowUndefinedSize(shapeA, shapeB, errorMessagePrefix = '') { | ||
util.assert(this.shapesEqualAllowUndefinedSize(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); | ||
} | ||
shapesEqualAllowUndefinedSize(n1, n2) { | ||
if (n1.length !== n2.length) { | ||
return false; | ||
} | ||
for (let i = 0; i < n1.length; i++) { | ||
if (n1[i] !== -1 && n2[i] !== -1 && n1[i] !== n2[i]) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
} | ||
TensorArray.nextId = 0; | ||
//# sourceMappingURL=tensor_array.js.map |
@@ -22,2 +22,32 @@ /** | ||
switch (node.op) { | ||
case 'If': | ||
case 'StatelessIf': { | ||
const thenFunc = getParamValue('thenBranch', node, tensorMap, context); | ||
const elseFunc = getParamValue('elseBranch', node, tensorMap, context); | ||
const cond = getParamValue('cond', node, tensorMap, context); | ||
const args = getParamValue('args', node, tensorMap, context); | ||
const condValue = await cond.data(); | ||
if (condValue[0]) { | ||
return context.functionMap[thenFunc].executeFunctionAsync(args); | ||
} | ||
else { | ||
return context.functionMap[elseFunc].executeFunctionAsync(args); | ||
} | ||
} | ||
case 'While': | ||
case 'StatelessWhile': { | ||
const bodyFunc = getParamValue('body', node, tensorMap, context); | ||
const condFunc = getParamValue('cond', node, tensorMap, context); | ||
const args = getParamValue('args', node, tensorMap, context); | ||
const condTensor = (await context.functionMap[condFunc].executeFunctionAsync(args))[0]; | ||
let condValue = await condTensor.data(); | ||
let result = args; | ||
while (condValue[0]) { | ||
result = | ||
await context.functionMap[bodyFunc].executeFunctionAsync(result); | ||
const condTensor = (await context.functionMap[condFunc].executeFunctionAsync(result))[0]; | ||
condValue = await condTensor.data(); | ||
} | ||
return result; | ||
} | ||
case 'LoopCond': | ||
@@ -24,0 +54,0 @@ return [ |
@@ -64,2 +64,8 @@ /** | ||
} | ||
case 'Cumsum': { | ||
const axis = getParamValue('axis', node, tensorMap, context); | ||
const exclusive = getParamValue('exclusive', node, tensorMap, context); | ||
const reverse = getParamValue('reverse', node, tensorMap, context); | ||
return [tfc.cumsum(getParamValue('x', node, tensorMap, context), axis, exclusive, reverse)]; | ||
} | ||
default: | ||
@@ -66,0 +72,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
@@ -54,2 +54,5 @@ /** | ||
} | ||
case 'BroadcastTo': { | ||
return [tfc.broadcastTo(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))]; | ||
} | ||
default: | ||
@@ -56,0 +59,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
@@ -79,8 +79,8 @@ /** | ||
export function parseNodeName(name) { | ||
const index = name.lastIndexOf(':'); | ||
if (index === -1) { | ||
const parts = name.split(':'); | ||
if (parts.length === 1) { | ||
return [name, 0]; | ||
} | ||
const nodeName = name.substring(0, index); | ||
return [nodeName, Number(name.substring(index + 1))]; | ||
const nodeName = parts[0]; | ||
return [nodeName, Number(parts[parts.length - 1])]; | ||
} | ||
@@ -87,0 +87,0 @@ export function split(arr, size) { |
@@ -178,4 +178,50 @@ /** | ||
'inputs': [{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }] | ||
}, | ||
{ | ||
'tfOpName': 'StatelessIf', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'cond', 'type': 'tensor' }, | ||
{ 'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors' } | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func' }, | ||
{ 'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func' } | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'If', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'cond', 'type': 'tensor' }, | ||
{ 'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors' } | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func' }, | ||
{ 'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func' } | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'StatelessWhile', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors' }, | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'cond', 'name': 'cond', 'type': 'func' }, | ||
{ 'tfName': 'body', 'name': 'body', 'type': 'func' } | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'While', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors' }, | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'cond', 'name': 'cond', 'type': 'func' }, | ||
{ 'tfName': 'body', 'name': 'body', 'type': 'func' } | ||
] | ||
} | ||
]; | ||
//# sourceMappingURL=control.js.map |
@@ -147,2 +147,8 @@ /** | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
{ 'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]' } | ||
@@ -212,3 +218,4 @@ ] | ||
{ 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' }, | ||
{ 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, { | ||
{ 'tfName': 'padding', 'name': 'pad', 'type': 'string' }, | ||
{ | ||
'tfName': 'data_format', | ||
@@ -218,3 +225,9 @@ 'name': 'dataFormat', | ||
'notSupported': true | ||
} | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
] | ||
@@ -237,2 +250,8 @@ }, | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
{ 'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]' } | ||
@@ -256,2 +275,8 @@ ] | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
{ 'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]' } | ||
@@ -258,0 +283,0 @@ ] |
@@ -96,4 +96,16 @@ /** | ||
'attrs': [{ 'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool' }] | ||
}, | ||
{ | ||
'tfOpName': 'Cumsum', | ||
'category': 'reduction', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'x', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'axis', 'type': 'number' }, | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'exclusive', 'name': 'exclusive', 'type': 'bool' }, | ||
{ 'tfName': 'reverse', 'name': 'reverse', 'type': 'bool' } | ||
] | ||
} | ||
]; | ||
//# sourceMappingURL=reduction.js.map |
@@ -118,4 +118,13 @@ /** | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'BroadcastTo', | ||
'category': 'transformation', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'x', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'shape', 'type': 'number[]' }, | ||
], | ||
'attrs': [] | ||
} | ||
]; | ||
//# sourceMappingURL=transformation.js.map |
@@ -71,3 +71,3 @@ /** | ||
'number', 'string', 'number[]', 'bool', 'shape', 'tensor', 'tensors', | ||
'dtype', 'string[]' | ||
'dtype', 'string[]', 'func', 'dtype[]' | ||
] | ||
@@ -74,0 +74,0 @@ }, |
@@ -28,2 +28,5 @@ /** | ||
private mapNode; | ||
private mapFunction; | ||
private mapArgsToSignature; | ||
private mapArgToTensorInfo; | ||
} | ||
@@ -42,2 +45,5 @@ export declare function decodeBase64(text: string): string; | ||
export declare function parseDtypeParam(value: string | tensorflow.DataType): DataType; | ||
export declare function getFuncParam(attrs: { | ||
[key: string]: tensorflow.IAttrValue; | ||
}, name: string, def: string): string; | ||
export declare function getDtypeParam(attrs: { | ||
@@ -44,0 +50,0 @@ [key: string]: tensorflow.IAttrValue; |
@@ -121,3 +121,18 @@ /** | ||
} | ||
return { nodes, inputs, outputs, weights, placeholders, signature }; | ||
let functions = {}; | ||
if (graph.library != null && graph.library.function != null) { | ||
functions = graph.library.function.reduce((functions, func) => { | ||
functions[func.signature.name] = this.mapFunction(func); | ||
return functions; | ||
}, {}); | ||
} | ||
return { | ||
nodes, | ||
inputs, | ||
outputs, | ||
weights, | ||
placeholders, | ||
signature, | ||
functions | ||
}; | ||
} | ||
@@ -227,2 +242,8 @@ mapSignatureEntries(entries) { | ||
break; | ||
case 'func': | ||
value = getFuncParam(node.attr, param.tfName, param.defaultValue); | ||
if (value === undefined && !!param.tfDeprecatedName) { | ||
value = getFuncParam(node.attr, param.tfDeprecatedName, param.defaultValue); | ||
} | ||
break; | ||
case 'tensor': | ||
@@ -240,2 +261,76 @@ case 'tensors': | ||
} | ||
// map the TFunctionDef to TFJS graph object | ||
mapFunction(functionDef) { | ||
const tfNodes = functionDef.nodeDef; | ||
const placeholders = []; | ||
const weights = []; | ||
let nodes = {}; | ||
if (tfNodes != null) { | ||
nodes = tfNodes.reduce((map, node) => { | ||
map[node.name] = this.mapNode(node); | ||
if (node.op === 'Const') { | ||
weights.push(map[node.name]); | ||
} | ||
return map; | ||
}, {}); | ||
} | ||
const inputs = []; | ||
const outputs = []; | ||
functionDef.signature.inputArg.forEach(arg => { | ||
const [nodeName,] = getNodeNameAndIndex(arg.name); | ||
const node = { | ||
name: nodeName, | ||
op: 'Placeholder', | ||
inputs: [], | ||
inputNames: [], | ||
category: 'graph', | ||
inputParams: {}, | ||
attrParams: { dtype: { value: parseDtypeParam(arg.type), type: 'dtype' } }, | ||
children: [] | ||
}; | ||
node.signatureKey = arg.name; | ||
inputs.push(node); | ||
nodes[nodeName] = node; | ||
}); | ||
const allNodes = Object.keys(nodes); | ||
allNodes.forEach(key => { | ||
const node = nodes[key]; | ||
node.inputNames.forEach(name => { | ||
const [nodeName,] = getNodeNameAndIndex(name); | ||
node.inputs.push(nodes[nodeName]); | ||
nodes[nodeName].children.push(node); | ||
}); | ||
}); | ||
const returnNodeMap = functionDef.ret; | ||
functionDef.signature.outputArg.forEach(output => { | ||
const [nodeName, index] = getNodeNameAndIndex(returnNodeMap[output.name]); | ||
const node = nodes[nodeName]; | ||
if (node != null) { | ||
node.defaultOutput = index; | ||
outputs.push(node); | ||
} | ||
}); | ||
const signature = this.mapArgsToSignature(functionDef); | ||
return { nodes, inputs, outputs, weights, placeholders, signature }; | ||
} | ||
mapArgsToSignature(functionDef) { | ||
return { | ||
methodName: functionDef.signature.name, | ||
inputs: functionDef.signature.inputArg.reduce((map, arg) => { | ||
map[arg.name] = this.mapArgToTensorInfo(arg); | ||
return map; | ||
}, {}), | ||
outputs: functionDef.signature.outputArg.reduce((map, arg) => { | ||
map[arg.name] = this.mapArgToTensorInfo(arg, functionDef.ret); | ||
return map; | ||
}, {}), | ||
}; | ||
} | ||
mapArgToTensorInfo(arg, nameMap) { | ||
let name = arg.name; | ||
if (nameMap != null) { | ||
name = nameMap[name]; | ||
} | ||
return { name, dtype: arg.type }; | ||
} | ||
} | ||
@@ -300,2 +395,9 @@ export function decodeBase64(text) { | ||
} | ||
export function getFuncParam(attrs, name, def) { | ||
const param = attrs[name]; | ||
if (param && param.func) { | ||
return param.func.name; | ||
} | ||
return def; | ||
} | ||
export function getDtypeParam(attrs, name, def) { | ||
@@ -302,0 +404,0 @@ const param = attrs[name]; |
@@ -21,3 +21,3 @@ /** | ||
import { ExecutionContext } from '../executor/execution_context'; | ||
export declare type ParamType = 'number' | 'string' | 'string[]' | 'number[]' | 'bool' | 'bool[]' | 'shape' | 'shape[]' | 'tensor' | 'tensors' | 'dtype' | 'dtype[]'; | ||
export declare type ParamType = 'number' | 'string' | 'string[]' | 'number[]' | 'bool' | 'bool[]' | 'shape' | 'shape[]' | 'tensor' | 'tensors' | 'dtype' | 'dtype[]' | 'func'; | ||
export declare type Category = 'arithmetic' | 'basic_math' | 'control' | 'convolution' | 'custom' | 'dynamic' | 'evaluation' | 'image' | 'creation' | 'graph' | 'logical' | 'matrices' | 'normalization' | 'reduction' | 'slice_join' | 'spectral' | 'transformation'; | ||
@@ -68,2 +68,3 @@ export declare interface ParamMapper { | ||
}; | ||
defaultOutput?: number; | ||
} | ||
@@ -79,2 +80,5 @@ export declare interface Graph { | ||
signature?: tensorflow.ISignatureDef; | ||
functions?: { | ||
[key: string]: Graph; | ||
}; | ||
} | ||
@@ -81,0 +85,0 @@ export declare type ValueType = string | string[] | number | number[] | number[][] | boolean | boolean[] | Tensor | Tensor[]; |
/** @license See the LICENSE file. */ | ||
declare const version = "2.0.0"; | ||
declare const version = "2.0.1"; | ||
export { version }; |
/** @license See the LICENSE file. */ | ||
// This code is auto-generated, do not modify this file! | ||
const version = '2.0.0'; | ||
const version = '2.0.1'; | ||
export { version }; | ||
//# sourceMappingURL=version.js.map |
{ | ||
"name": "@tensorflow/tfjs-converter", | ||
"version": "2.0.0", | ||
"version": "2.0.1", | ||
"description": "Tensorflow model converter for javascript", | ||
@@ -18,3 +18,3 @@ "main": "dist/tf-converter.node.js", | ||
"peerDependencies": { | ||
"@tensorflow/tfjs-core": "2.0.0" | ||
"@tensorflow/tfjs-core": "2.0.1" | ||
}, | ||
@@ -25,4 +25,4 @@ "devDependencies": { | ||
"@rollup/plugin-typescript": "^3.0.0", | ||
"@tensorflow/tfjs-backend-cpu": "2.0.0", | ||
"@tensorflow/tfjs-core": "2.0.0", | ||
"@tensorflow/tfjs-backend-cpu": "2.0.1", | ||
"@tensorflow/tfjs-core": "2.0.1", | ||
"@types/deep-equal": "^1.0.1", | ||
@@ -29,0 +29,0 @@ "@types/jasmine": "~2.8.6", |
@@ -157,5 +157,9 @@ # Getting started | ||
|`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.| | ||
|`--quantization_bytes` | How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.| | ||
|`--quantization_bytes` | (Deprecated) How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.| | ||
|`--quantize_float16` | Comma separated list of node names to apply float16 quantization. You can also use wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any nodes the default behavior will match all nodes. | | ||
|`--quantize_uint8` | Comma separated list of node names to apply 1-byte affine quantization. You can also use wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any nodes the default behavior will match all nodes. | | ||
|`--quantize_uint16` | Comma separated list of node names to apply 2-byte affine quantization. You can also use wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any nodes the default behavior will match all nodes. | | ||
|`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).| | ||
|<nobr>`--output_node_names`</nobr>| Only applicable to Frozen Model. The names of the output nodes, separated by commas.| | ||
|<nobr>`--control_flow_v2`</nobr>| Only applicable to TF 2.x Saved Model. This flag improve performance on models with control flow ops, default to False.| | ||
@@ -220,3 +224,3 @@ __Note: If you want to convert TensorFlow session bundle, you can install older versions of the tensorflowjs pip package, i.e. `pip install tensorflowjs==0.8.6`.__ | ||
--output_format tfjs_layers_model \ | ||
--quantization_bytes 2 \ | ||
--quantize_uint16 \ | ||
original_model/model.json | ||
@@ -385,10 +389,14 @@ quantized_model/ | ||
Yes, you can use the --quantization_bytes option to compress int32/float32 to 1 | ||
or 2 bytes. Here is | ||
an example of 8-bit quantization: | ||
Yes, you can use the --quantize_{float16, uint8, uint16} flags to compress | ||
weights with 1 byte integer quantization (`uint8`) or 2 byte integer | ||
(`uint16`)/float (`float16`) quantization. | ||
Quantizing to float16 may provide better accuracy over | ||
2 byte affine integer scaling (`uint16`). 1-byte affine quantization, | ||
i.e., `uint8` provides a 4x size reduction at the cost of accuracy. | ||
For example, we can quantize our MobileNet model using float16 quantization: | ||
``` | ||
tensorflowjs_converter \ | ||
tensorflowjs_converter | ||
--quantize_float16 \ | ||
--input_format=tf_hub \ | ||
--quantization_bytes=1 | ||
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \ | ||
@@ -398,2 +406,13 @@ /mobilenet/web_model | ||
You can also quantize specific weights as well as weight groupings using | ||
a wildcard replacement. For example, | ||
``` | ||
tensorflowjs_converter | ||
--quantize_float16="conv/*/weights" | ||
``` | ||
which will quantize all weights that match the pattern conv/*/weights. | ||
This will exclude biases and any weights that don't begin with conv/. | ||
This can be a powerful tool to reduce model size while trying to maximize | ||
performance. | ||
__5. Why is the predict() method for inference so much slower on the first call than the subsequent calls?__ | ||
@@ -400,0 +419,0 @@ |
@@ -22,2 +22,3 @@ /** | ||
import {TensorArray} from './tensor_array'; | ||
import {FunctionExecutor} from './types'; | ||
@@ -48,3 +49,4 @@ export interface ExecutionContextInfo { | ||
public readonly weightMap: NamedTensorsMap, | ||
public readonly tensorArrayMap: TensorArrayMap) { | ||
public readonly tensorArrayMap: TensorArrayMap, | ||
public readonly functionMap: {[key: string]: FunctionExecutor} = {}) { | ||
this.generateCurrentContextIds(); | ||
@@ -51,0 +53,0 @@ } |
@@ -28,2 +28,3 @@ /** | ||
import {getExecutionSubgraph, getNodesInTopologicalOrder, isControlFlow} from './model_analysis'; | ||
import {FunctionExecutor} from './types'; | ||
@@ -35,6 +36,6 @@ interface NodeWithContexts { | ||
export class GraphExecutor { | ||
export class GraphExecutor implements FunctionExecutor { | ||
private compiledMap: Map<string, Node[]> = new Map(); | ||
private _weightMap: NamedTensorsMap = {}; | ||
private weightIds: number[]; | ||
private _weightIds: number[]; | ||
private _signature: ISignatureDef; | ||
@@ -44,9 +45,22 @@ private _inputs: Node[]; | ||
private SEPERATOR = ','; | ||
private _functions: {[key: string]: Graph} = {}; | ||
private _functionExecutorMap: {[key: string]: FunctionExecutor} = {}; | ||
get weightIds(): number[] { | ||
return this.parent ? this.parent.weightIds : this._weightIds; | ||
} | ||
get functionExecutorMap(): {[key: string]: FunctionExecutor} { | ||
return this.parent ? this.parent.functionExecutorMap : | ||
this._functionExecutorMap; | ||
} | ||
get weightMap(): NamedTensorsMap { | ||
return this._weightMap; | ||
return this.parent ? this.parent.weightMap : this._weightMap; | ||
} | ||
set weightMap(weightMap: NamedTensorsMap) { | ||
const weightIds = Object.keys(weightMap).map( | ||
key => weightMap[key].map(tensor => tensor.id)); | ||
this.weightIds = [].concat(...weightIds); | ||
this._weightIds = [].concat(...weightIds); | ||
this._weightMap = weightMap; | ||
@@ -88,9 +102,35 @@ } | ||
get outputNodes(): string[] { | ||
return this._outputs.map(node => node.signatureKey || node.name); | ||
return this._outputs.map((node) => { | ||
const name = node.signatureKey || node.name; | ||
return node.defaultOutput ? (`${name}:${node.defaultOutput}`) : name; | ||
}); | ||
} | ||
constructor(private graph: Graph) { | ||
get functions(): {[key: string]: ISignatureDef} { | ||
return Object.keys(this._functions).reduce((map, key) => { | ||
map[key] = this._functions[key].signature; | ||
return map; | ||
}, {} as {[key: string]: ISignatureDef}); | ||
} | ||
/** | ||
* | ||
* @param graph Graph the model or function graph to be executed. | ||
* @param parent When building function exector you need to set the parent | ||
* executor. Since the weights and function executor maps are set at parant | ||
* level, that function executor can access the function maps and weight maps | ||
* through the parent. | ||
*/ | ||
constructor(private graph: Graph, private parent?: GraphExecutor) { | ||
this._outputs = graph.outputs; | ||
this._inputs = graph.inputs; | ||
this._signature = graph.signature; | ||
this._functions = graph.functions; | ||
// create sub-graph executors | ||
if (graph.functions != null) { | ||
Object.keys(graph.functions).forEach(name => { | ||
this._functionExecutorMap[name] = | ||
new GraphExecutor(graph.functions[name], this); | ||
}); | ||
} | ||
} | ||
@@ -161,3 +201,4 @@ | ||
return tidy(() => { | ||
const context = new ExecutionContext(this._weightMap, tensorArrayMap); | ||
const context = new ExecutionContext( | ||
this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
const tensorsMap: NamedTensorsMap = {...this.weightMap}; | ||
@@ -249,5 +290,8 @@ Object.keys(inputs).forEach(name => { | ||
* array. | ||
* @param disableWarning disable the no dynamic ops warning message, default | ||
* to false | ||
*/ | ||
async executeAsync(inputs: NamedTensorMap, outputs: string[]): | ||
Promise<Tensor[]> { | ||
async executeAsync( | ||
inputs: NamedTensorMap, outputs: string[], | ||
disableWarning = false): Promise<Tensor[]> { | ||
inputs = this.mapInputs(inputs); | ||
@@ -259,8 +303,9 @@ this.checkInputs(inputs); | ||
const tensorArrayMap: TensorArrayMap = {}; | ||
const context = new ExecutionContext(this._weightMap, tensorArrayMap); | ||
const context = new ExecutionContext( | ||
this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
// Graph with control flow op requires runtime evaluation of the execution | ||
// order, while without control flow the execution order is pre-determined | ||
// in the compile method. | ||
const tensorMap = | ||
await this.executeWithControlFlow(inputs, context, outputs); | ||
const tensorMap = await this.executeWithControlFlow( | ||
inputs, context, outputs, disableWarning); | ||
const results = outputs.map(name => getTensor(name, tensorMap, context)); | ||
@@ -285,2 +330,10 @@ | ||
async executeFunctionAsync(inputs: Tensor[]): Promise<Tensor[]> { | ||
const mappedInputs = inputs.reduce((map, tensor, index) => { | ||
map[this.inputs[index].name] = tensor; | ||
return map; | ||
}, {} as NamedTensorMap); | ||
return this.executeAsync(mappedInputs, this.outputNodes, true); | ||
} | ||
/** | ||
@@ -291,6 +344,7 @@ * When there are control flow nodes in the graph, the graph execution use | ||
* @param context the execution context object for current execution. | ||
* @param disableWarning disable no async op warning | ||
*/ | ||
private async executeWithControlFlow( | ||
inputs: NamedTensorMap, context: ExecutionContext, | ||
outputNames: string[]): Promise<NamedTensorsMap> { | ||
inputs: NamedTensorMap, context: ExecutionContext, outputNames: string[], | ||
disableWarning: boolean): Promise<NamedTensorsMap> { | ||
const names = Object.keys(inputs); | ||
@@ -324,3 +378,3 @@ const inputNodes = | ||
} | ||
if (dynamicNode == null) { | ||
if (dynamicNode == null && !disableWarning) { | ||
console.warn( | ||
@@ -327,0 +381,0 @@ `This model execution did not contain any nodes with control flow ` + |
@@ -29,4 +29,4 @@ /** | ||
/** | ||
* A `tf.GraphModel` is a directed, acyclic graph of built from | ||
* SavedModel GraphDef and allows inference exeuction. | ||
* A `tf.GraphModel` is a directed, acyclic graph built from a | ||
* SavedModel GraphDef and allows inference execution. | ||
* | ||
@@ -118,3 +118,14 @@ * A `tf.GraphModel` can only be created by loading from a model converted from | ||
} | ||
this.artifacts = await this.handler.load(); | ||
const artifacts = await this.handler.load(); | ||
return this.loadSync(artifacts); | ||
} | ||
/** | ||
* Synchronously construct the in memory weight map and | ||
* compile the inference graph. | ||
*/ | ||
/** @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ | ||
loadSync(artifacts:io.ModelArtifacts) { | ||
this.artifacts = artifacts; | ||
const graph = this.artifacts.modelTopology as tensorflow.IGraphDef; | ||
@@ -121,0 +132,0 @@ let signature = {}; |
@@ -131,3 +131,6 @@ /** | ||
const CONTROL_FLOW_OPS = ['Switch', 'Merge', 'Enter', 'Exit', 'NextIteration']; | ||
const CONTROL_FLOW_OPS = [ | ||
'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf', | ||
'StatelessWhile' | ||
]; | ||
const DYNAMIC_SHAPE_OPS = [ | ||
@@ -134,0 +137,0 @@ 'NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where' |
@@ -18,3 +18,4 @@ /** | ||
import {concat, DataType, slice, stack, Tensor, tensor, tidy, unstack, util} from '@tensorflow/tfjs-core'; | ||
import {concat, DataType, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; | ||
import {assertShapesMatchAllowUndefinedSize} from './tensor_utils'; | ||
@@ -129,3 +130,3 @@ export interface TensorWithState { | ||
this.assertShapesMatchAllowUndefinedSize( | ||
assertShapesMatchAllowUndefinedSize( | ||
this.elementShape, tensor.shape, | ||
@@ -199,3 +200,3 @@ `TensorArray ${this.name}: Could not write to TensorArray index ${ | ||
this.assertShapesMatchAllowUndefinedSize( | ||
assertShapesMatchAllowUndefinedSize( | ||
this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: '); | ||
@@ -226,3 +227,3 @@ | ||
this.assertShapesMatchAllowUndefinedSize( | ||
assertShapesMatchAllowUndefinedSize( | ||
this.elementShape, tensors[0].shape, | ||
@@ -310,27 +311,2 @@ `TensorArray shape mismatch: tensor array shape (${ | ||
} | ||
/** | ||
* This differs from util.assertShapesMatch in that it allows values of | ||
* negative one, an undefined size of a dimensinon, in a shape to match | ||
* anything. | ||
*/ | ||
private assertShapesMatchAllowUndefinedSize( | ||
shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void { | ||
util.assert( | ||
this.shapesEqualAllowUndefinedSize(shapeA, shapeB), | ||
() => | ||
errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); | ||
} | ||
private shapesEqualAllowUndefinedSize(n1: number[], n2: number[]) { | ||
if (n1.length !== n2.length) { | ||
return false; | ||
} | ||
for (let i = 0; i < n1.length; i++) { | ||
if (n1[i] !== -1 && n2[i] !== -1 && n1[i] !== n2[i]) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
} |
@@ -32,2 +32,41 @@ /** | ||
switch (node.op) { | ||
case 'If': | ||
case 'StatelessIf': { | ||
const thenFunc = | ||
getParamValue('thenBranch', node, tensorMap, context) as string; | ||
const elseFunc = | ||
getParamValue('elseBranch', node, tensorMap, context) as string; | ||
const cond = | ||
getParamValue('cond', node, tensorMap, context) as tfc.Tensor; | ||
const args = | ||
getParamValue('args', node, tensorMap, context) as tfc.Tensor[]; | ||
const condValue = await cond.data(); | ||
if (condValue[0]) { | ||
return context.functionMap[thenFunc].executeFunctionAsync(args); | ||
} else { | ||
return context.functionMap[elseFunc].executeFunctionAsync(args); | ||
} | ||
} | ||
case 'While': | ||
case 'StatelessWhile': { | ||
const bodyFunc = | ||
getParamValue('body', node, tensorMap, context) as string; | ||
const condFunc = | ||
getParamValue('cond', node, tensorMap, context) as string; | ||
const args = | ||
getParamValue('args', node, tensorMap, context) as tfc.Tensor[]; | ||
const condTensor = | ||
(await context.functionMap[condFunc].executeFunctionAsync(args))[0]; | ||
let condValue = await condTensor.data(); | ||
let result: tfc.Tensor[] = args; | ||
while (condValue[0]) { | ||
result = | ||
await context.functionMap[bodyFunc].executeFunctionAsync(result); | ||
const condTensor = | ||
(await context.functionMap[condFunc].executeFunctionAsync( | ||
result))[0]; | ||
condValue = await condTensor.data(); | ||
} | ||
return result; | ||
} | ||
case 'LoopCond': | ||
@@ -34,0 +73,0 @@ return [ |
@@ -27,5 +27,5 @@ /** | ||
export const executeOp: InternalOpExecutor = (node: Node, | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
switch (node.op) { | ||
@@ -98,2 +98,12 @@ case 'Max': { | ||
} | ||
case 'Cumsum': { | ||
const axis = getParamValue('axis', node, tensorMap, context) as number; | ||
const exclusive = | ||
getParamValue('exclusive', node, tensorMap, context) as boolean; | ||
const reverse = | ||
getParamValue('reverse', node, tensorMap, context) as boolean; | ||
return [tfc.cumsum( | ||
getParamValue('x', node, tensorMap, context) as tfc.Tensor, axis, | ||
exclusive, reverse)]; | ||
} | ||
default: | ||
@@ -100,0 +110,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
@@ -27,5 +27,5 @@ /** | ||
export const executeOp: InternalOpExecutor = (node: Node, | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
switch (node.op) { | ||
@@ -92,2 +92,7 @@ case 'Cast': { | ||
} | ||
case 'BroadcastTo': { | ||
return [tfc.broadcastTo( | ||
getParamValue('x', node, tensorMap, context) as tfc.Tensor, | ||
getParamValue('shape', node, tensorMap, context) as number[])]; | ||
} | ||
default: | ||
@@ -94,0 +99,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
@@ -103,9 +103,9 @@ /** | ||
export function parseNodeName(name: string): [string, number] { | ||
const index = name.lastIndexOf(':'); | ||
if (index === -1) { | ||
const parts = name.split(':'); | ||
if (parts.length === 1) { | ||
return [name, 0]; | ||
} | ||
const nodeName = name.substring(0, index); | ||
return [nodeName, Number(name.substring(index + 1))]; | ||
const nodeName = parts[0]; | ||
return [nodeName, Number(parts[parts.length - 1])]; | ||
} | ||
@@ -112,0 +112,0 @@ |
@@ -37,4 +37,3 @@ import {OpMapper} from '../types'; | ||
'category': 'control', | ||
'inputs': | ||
[{'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors'}] | ||
'inputs': [{'start': 0, 'end': 0, 'name': 'tensors', 'type': 'tensors'}] | ||
}, | ||
@@ -183,3 +182,49 @@ { | ||
'inputs': [{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}] | ||
}, | ||
{ | ||
'tfOpName': 'StatelessIf', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'cond', 'type': 'tensor'}, | ||
{'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors'} | ||
], | ||
'attrs': [ | ||
{'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func'}, | ||
{'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func'} | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'If', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'cond', 'type': 'tensor'}, | ||
{'start': 1, 'end': 0, 'name': 'args', 'type': 'tensors'} | ||
], | ||
'attrs': [ | ||
{'tfName': 'then_branch', 'name': 'thenBranch', 'type': 'func'}, | ||
{'tfName': 'else_branch', 'name': 'elseBranch', 'type': 'func'} | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'StatelessWhile', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors'}, | ||
], | ||
'attrs': [ | ||
{'tfName': 'cond', 'name': 'cond', 'type': 'func'}, | ||
{'tfName': 'body', 'name': 'body', 'type': 'func'} | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'While', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'end': 0, 'name': 'args', 'type': 'tensors'}, | ||
], | ||
'attrs': [ | ||
{'tfName': 'cond', 'name': 'cond', 'type': 'func'}, | ||
{'tfName': 'body', 'name': 'body', 'type': 'func'} | ||
] | ||
} | ||
]; |
@@ -150,2 +150,8 @@ import {OpMapper} from '../types'; | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
{'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]'} | ||
@@ -215,3 +221,4 @@ ] | ||
{'tfName': 'strides', 'name': 'strides', 'type': 'number[]'}, | ||
{'tfName': 'padding', 'name': 'pad', 'type': 'string'}, { | ||
{'tfName': 'padding', 'name': 'pad', 'type': 'string'}, | ||
{ | ||
'tfName': 'data_format', | ||
@@ -221,3 +228,9 @@ 'name': 'dataFormat', | ||
'notSupported': true | ||
} | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
] | ||
@@ -240,2 +253,8 @@ }, | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
{'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]'} | ||
@@ -259,2 +278,8 @@ ] | ||
}, | ||
{ | ||
'tfName': 'explicit_paddings', | ||
'name': 'explicitPaddings', | ||
'type': 'number[]', | ||
'defaultValue': [] | ||
}, | ||
{'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]'} | ||
@@ -261,0 +286,0 @@ ] |
@@ -99,3 +99,15 @@ import {OpMapper} from '../types'; | ||
'attrs': [{'tfName': 'keep_dims', 'name': 'keepDims', 'type': 'bool'}] | ||
}, | ||
{ | ||
'tfOpName': 'Cumsum', | ||
'category': 'reduction', | ||
'inputs': [ | ||
{'start': 0, 'name': 'x', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'axis', 'type': 'number'}, | ||
], | ||
'attrs': [ | ||
{'tfName': 'exclusive', 'name': 'exclusive', 'type': 'bool'}, | ||
{'tfName': 'reverse', 'name': 'reverse', 'type': 'bool'} | ||
] | ||
} | ||
]; |
@@ -121,3 +121,12 @@ import {OpMapper} from '../types'; | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'BroadcastTo', | ||
'category': 'transformation', | ||
'inputs': [ | ||
{'start': 0, 'name': 'x', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'shape', 'type': 'number[]'}, | ||
], | ||
'attrs': [] | ||
} | ||
]; |
@@ -72,3 +72,3 @@ /** | ||
'number', 'string', 'number[]', 'bool', 'shape', 'tensor', 'tensors', | ||
'dtype', 'string[]' | ||
'dtype', 'string[]', 'func', 'dtype[]' | ||
] | ||
@@ -75,0 +75,0 @@ }, |
@@ -139,3 +139,19 @@ /** | ||
return {nodes, inputs, outputs, weights, placeholders, signature}; | ||
let functions = {}; | ||
if (graph.library != null && graph.library.function != null) { | ||
functions = graph.library.function.reduce((functions, func) => { | ||
functions[func.signature.name] = this.mapFunction(func); | ||
return functions; | ||
}, {} as {[key: string]: Graph}); | ||
} | ||
return { | ||
nodes, | ||
inputs, | ||
outputs, | ||
weights, | ||
placeholders, | ||
signature, | ||
functions | ||
}; | ||
} | ||
@@ -286,2 +302,11 @@ | ||
break; | ||
case 'func': | ||
value = getFuncParam( | ||
node.attr, param.tfName, param.defaultValue as string); | ||
if (value === undefined && !!param.tfDeprecatedName) { | ||
value = getFuncParam( | ||
node.attr, param.tfDeprecatedName, | ||
param.defaultValue as string); | ||
} | ||
break; | ||
case 'tensor': | ||
@@ -300,2 +325,91 @@ case 'tensors': | ||
} | ||
// map the TFunctionDef to TFJS graph object | ||
private mapFunction(functionDef: tensorflow.IFunctionDef): Graph { | ||
const tfNodes = functionDef.nodeDef; | ||
const placeholders: Node[] = []; | ||
const weights: Node[] = []; | ||
let nodes: {[key: string]: Node} = {}; | ||
if (tfNodes != null) { | ||
nodes = tfNodes.reduce<{[key: string]: Node}>((map, node) => { | ||
map[node.name] = this.mapNode(node); | ||
if (node.op === 'Const') { | ||
weights.push(map[node.name]); | ||
} | ||
return map; | ||
}, {}); | ||
} | ||
const inputs: Node[] = []; | ||
const outputs: Node[] = []; | ||
functionDef.signature.inputArg.forEach(arg => { | ||
const [nodeName, ] = getNodeNameAndIndex(arg.name); | ||
const node: Node = { | ||
name: nodeName, | ||
op: 'Placeholder', | ||
inputs: [], | ||
inputNames: [], | ||
category: 'graph', | ||
inputParams: {}, | ||
attrParams: {dtype: {value: parseDtypeParam(arg.type), type: 'dtype'}}, | ||
children: [] | ||
}; | ||
node.signatureKey = arg.name; | ||
inputs.push(node); | ||
nodes[nodeName] = node; | ||
}); | ||
const allNodes = Object.keys(nodes); | ||
allNodes.forEach(key => { | ||
const node = nodes[key]; | ||
node.inputNames.forEach(name => { | ||
const [nodeName, ] = getNodeNameAndIndex(name); | ||
node.inputs.push(nodes[nodeName]); | ||
nodes[nodeName].children.push(node); | ||
}); | ||
}); | ||
const returnNodeMap = functionDef.ret; | ||
functionDef.signature.outputArg.forEach(output => { | ||
const [nodeName, index] = getNodeNameAndIndex(returnNodeMap[output.name]); | ||
const node = nodes[nodeName]; | ||
if (node != null) { | ||
node.defaultOutput = index; | ||
outputs.push(node); | ||
} | ||
}); | ||
const signature = this.mapArgsToSignature(functionDef); | ||
return {nodes, inputs, outputs, weights, placeholders, signature}; | ||
} | ||
private mapArgsToSignature(functionDef: tensorflow.IFunctionDef): | ||
tensorflow.ISignatureDef { | ||
return { | ||
methodName: functionDef.signature.name, | ||
inputs: functionDef.signature.inputArg.reduce( | ||
(map, arg) => { | ||
map[arg.name] = this.mapArgToTensorInfo(arg); | ||
return map; | ||
}, | ||
{} as {[key: string]: tensorflow.ITensorInfo}), | ||
outputs: functionDef.signature.outputArg.reduce( | ||
(map, arg) => { | ||
map[arg.name] = this.mapArgToTensorInfo(arg, functionDef.ret); | ||
return map; | ||
}, | ||
{} as {[key: string]: tensorflow.ITensorInfo}), | ||
}; | ||
} | ||
private mapArgToTensorInfo( | ||
arg: tensorflow.OpDef.IArgDef, | ||
nameMap?: {[key: string]: string}): tensorflow.ITensorInfo { | ||
let name = arg.name; | ||
if (nameMap != null) { | ||
name = nameMap[name]; | ||
} | ||
return {name, dtype: arg.type}; | ||
} | ||
} | ||
@@ -374,2 +488,12 @@ | ||
export function getFuncParam( | ||
attrs: {[key: string]: tensorflow.IAttrValue}, name: string, | ||
def: string): string { | ||
const param = attrs[name]; | ||
if (param && param.func) { | ||
return param.func.name; | ||
} | ||
return def; | ||
} | ||
export function getDtypeParam( | ||
@@ -376,0 +500,0 @@ attrs: {[key: string]: tensorflow.IAttrValue}, name: string, |
@@ -24,3 +24,3 @@ /** | ||
export type ParamType = 'number'|'string'|'string[]'|'number[]'|'bool'|'bool[]'| | ||
'shape'|'shape[]'|'tensor'|'tensors'|'dtype'|'dtype[]'; | ||
'shape'|'shape[]'|'tensor'|'tensors'|'dtype'|'dtype[]'|'func'; | ||
export type Category = | ||
@@ -106,2 +106,3 @@ 'arithmetic'|'basic_math'|'control'|'convolution'|'custom'|'dynamic'| | ||
rawAttrs?: {[k: string]: tensorflow.IAttrValue}; | ||
defaultOutput?: number; | ||
} | ||
@@ -116,2 +117,3 @@ | ||
signature?: tensorflow.ISignatureDef; | ||
functions?: {[key: string]: Graph}; | ||
} | ||
@@ -118,0 +120,0 @@ |
/** @license See the LICENSE file. */ | ||
// This code is auto-generated, do not modify this file! | ||
const version = '2.0.0'; | ||
const version = '2.0.1'; | ||
export {version}; |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
5330403
260
39040
504