@tensorflow/tfjs-converter
Advanced tools
Comparing version 2.0.1 to 2.1.0
@@ -19,2 +19,3 @@ /** | ||
import { TensorArray } from '../executor/tensor_array'; | ||
import { TensorList } from '../executor/tensor_list'; | ||
export declare type NamedTensorMap = { | ||
@@ -29,2 +30,5 @@ [key: string]: Tensor; | ||
}; | ||
export declare type TensorListMap = { | ||
[key: number]: TensorList; | ||
}; | ||
export interface TensorInfo { | ||
@@ -31,0 +35,0 @@ name: string; |
@@ -18,4 +18,5 @@ /** | ||
import { Tensor } from '@tensorflow/tfjs-core'; | ||
import { NamedTensorsMap, TensorArrayMap } from '../data/types'; | ||
import { NamedTensorsMap, TensorArrayMap, TensorListMap } from '../data/types'; | ||
import { TensorArray } from './tensor_array'; | ||
import { TensorList } from './tensor_list'; | ||
import { FunctionExecutor } from './types'; | ||
@@ -39,2 +40,3 @@ export interface ExecutionContextInfo { | ||
readonly tensorArrayMap: TensorArrayMap; | ||
readonly tensorListMap: TensorListMap; | ||
readonly functionMap: { | ||
@@ -47,3 +49,3 @@ [key: string]: FunctionExecutor; | ||
private _currentContextIds; | ||
constructor(weightMap: NamedTensorsMap, tensorArrayMap: TensorArrayMap, functionMap?: { | ||
constructor(weightMap?: NamedTensorsMap, tensorArrayMap?: TensorArrayMap, tensorListMap?: TensorListMap, functionMap?: { | ||
[key: string]: FunctionExecutor; | ||
@@ -87,2 +89,5 @@ }); | ||
getTensorArray(id: number): TensorArray; | ||
addTensorList(tensorList: TensorList): void; | ||
getTensorList(id: number): TensorList; | ||
dispose(): void; | ||
} |
@@ -11,5 +11,6 @@ /** | ||
export class ExecutionContext { | ||
constructor(weightMap, tensorArrayMap, functionMap = {}) { | ||
constructor(weightMap = {}, tensorArrayMap = {}, tensorListMap = {}, functionMap = {}) { | ||
this.weightMap = weightMap; | ||
this.tensorArrayMap = tensorArrayMap; | ||
this.tensorListMap = tensorListMap; | ||
this.functionMap = functionMap; | ||
@@ -122,3 +123,17 @@ this.rootContext = { id: 0, frameName: '', iterationId: 0 }; | ||
} | ||
addTensorList(tensorList) { | ||
this.tensorListMap[tensorList.id] = tensorList; | ||
} | ||
getTensorList(id) { | ||
return this.tensorListMap[id]; | ||
} | ||
dispose() { | ||
for (const key in this.tensorArrayMap) { | ||
this.tensorArrayMap[key].clearAndClose(); | ||
} | ||
for (const key in this.tensorListMap) { | ||
this.tensorListMap[key].clearAndClose(); | ||
} | ||
} | ||
} | ||
//# sourceMappingURL=execution_context.js.map |
@@ -19,3 +19,3 @@ /** | ||
import { ISignatureDef } from '../data/compiled_api'; | ||
import { NamedTensorsMap, TensorInfo } from '../data/types'; | ||
import { NamedTensorsMap, TensorArrayMap, TensorInfo, TensorListMap } from '../data/types'; | ||
import { Graph } from '../operations/types'; | ||
@@ -82,8 +82,21 @@ import { FunctionExecutor } from './types'; | ||
* array. | ||
* @param disableWarning disable the no dynamic ops warning message, default | ||
* to false | ||
*/ | ||
executeAsync(inputs: NamedTensorMap, outputs: string[], disableWarning?: boolean): Promise<Tensor[]>; | ||
executeFunctionAsync(inputs: Tensor[]): Promise<Tensor[]>; | ||
executeAsync(inputs: NamedTensorMap, outputs: string[]): Promise<Tensor[]>; | ||
/** | ||
* Executes the inference for given input tensors in Async fashion. | ||
* @param inputs Tensor map for the model inputs, keyed by the input node | ||
* names. | ||
* @param outputs output node name from the Tensorflow model, if no outputs | ||
* are specified, the default outputs of the model would be used. You can | ||
* inspect intermediate nodes of the model by adding them to the outputs | ||
* array. | ||
* @param isFunctionExecution Flag for executing a function. | ||
* @param tensorArrayMap Optional, global TensorArray map by id. Used for | ||
* function execution. | ||
* @param tensorArrayMap Optinal global TensorList map by id. Used for | ||
* function execution. | ||
*/ | ||
private _executeAsync; | ||
executeFunctionAsync(inputs: Tensor[], tensorArrayMap: TensorArrayMap, tensorListMap: TensorListMap): Promise<Tensor[]>; | ||
/** | ||
* When there are control flow nodes in the graph, the graph execution use | ||
@@ -93,3 +106,3 @@ * ExecutionContext to keep track of the frames and loop iterators. | ||
* @param context the execution context object for current execution. | ||
* @param disableWarning disable no async op warning | ||
* @param isFunctionExecution Flag for executing a function. | ||
*/ | ||
@@ -96,0 +109,0 @@ private executeWithControlFlow; |
@@ -160,4 +160,5 @@ /** | ||
const tensorArrayMap = {}; | ||
const tensorListMap = {}; | ||
return tidy(() => { | ||
const context = new ExecutionContext(this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap); | ||
const tensorsMap = Object.assign({}, this.weightMap); | ||
@@ -184,2 +185,6 @@ Object.keys(inputs).forEach(name => { | ||
} | ||
// dispose the context for the root executor | ||
if (this.parent == null) { | ||
context.dispose(); | ||
} | ||
return outputs.map(name => getTensor(name, tensorsMap, context)); | ||
@@ -239,17 +244,33 @@ }); | ||
* array. | ||
* @param disableWarning disable the no dynamic ops warning message, default | ||
* to false | ||
*/ | ||
async executeAsync(inputs, outputs, disableWarning = false) { | ||
inputs = this.mapInputs(inputs); | ||
this.checkInputs(inputs); | ||
this.checkInputShapeAndType(inputs); | ||
outputs = this.mapOutputs(outputs); | ||
this.checkOutputs(outputs); | ||
const tensorArrayMap = {}; | ||
const context = new ExecutionContext(this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
async executeAsync(inputs, outputs) { | ||
return this._executeAsync(inputs, outputs); | ||
} | ||
/** | ||
* Executes the inference for given input tensors in Async fashion. | ||
* @param inputs Tensor map for the model inputs, keyed by the input node | ||
* names. | ||
* @param outputs output node name from the Tensorflow model, if no outputs | ||
* are specified, the default outputs of the model would be used. You can | ||
* inspect intermediate nodes of the model by adding them to the outputs | ||
* array. | ||
* @param isFunctionExecution Flag for executing a function. | ||
* @param tensorArrayMap Optional, global TensorArray map by id. Used for | ||
* function execution. | ||
* @param tensorArrayMap Optinal global TensorList map by id. Used for | ||
* function execution. | ||
*/ | ||
async _executeAsync(inputs, outputs, isFunctionExecution = false, tensorArrayMap = {}, tensorListMap = {}) { | ||
if (!isFunctionExecution) { | ||
inputs = this.mapInputs(inputs); | ||
this.checkInputs(inputs); | ||
this.checkInputShapeAndType(inputs); | ||
outputs = this.mapOutputs(outputs); | ||
this.checkOutputs(outputs); | ||
} | ||
const context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap); | ||
// Graph with control flow op requires runtime evaluation of the execution | ||
// order, while without control flow the execution order is pre-determined | ||
// in the compile method. | ||
const tensorMap = await this.executeWithControlFlow(inputs, context, outputs, disableWarning); | ||
const tensorMap = await this.executeWithControlFlow(inputs, context, outputs, isFunctionExecution); | ||
const results = outputs.map(name => getTensor(name, tensorMap, context)); | ||
@@ -269,5 +290,9 @@ // dispose all the intermediate tensors | ||
}); | ||
// dispose the context for the root executor | ||
if (this.parent == null) { | ||
context.dispose(); | ||
} | ||
return results; | ||
} | ||
async executeFunctionAsync(inputs) { | ||
async executeFunctionAsync(inputs, tensorArrayMap, tensorListMap) { | ||
const mappedInputs = inputs.reduce((map, tensor, index) => { | ||
@@ -277,3 +302,3 @@ map[this.inputs[index].name] = tensor; | ||
}, {}); | ||
return this.executeAsync(mappedInputs, this.outputNodes, true); | ||
return this._executeAsync(mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap); | ||
} | ||
@@ -285,5 +310,5 @@ /** | ||
* @param context the execution context object for current execution. | ||
* @param disableWarning disable no async op warning | ||
* @param isFunctionExecution Flag for executing a function. | ||
*/ | ||
async executeWithControlFlow(inputs, context, outputNames, disableWarning) { | ||
async executeWithControlFlow(inputs, context, outputNames, isFunctionExecution) { | ||
const names = Object.keys(inputs); | ||
@@ -310,3 +335,3 @@ const inputNodes = names.map(name => this.graph.nodes[parseNodeName(name)[0]]); | ||
} | ||
if (dynamicNode == null && !disableWarning) { | ||
if (dynamicNode == null && !isFunctionExecution) { | ||
console.warn(`This model execution did not contain any nodes with control flow ` + | ||
@@ -313,0 +338,0 @@ `or dynamic output shapes. You can use model.execute() instead.`); |
@@ -108,3 +108,3 @@ /** | ||
'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf', | ||
'StatelessWhile' | ||
'StatelessWhile', 'if', 'While' | ||
]; | ||
@@ -111,0 +111,0 @@ const DYNAMIC_SHAPE_OPS = [ |
@@ -36,10 +36,10 @@ /** | ||
readonly clearAfterRead: boolean; | ||
private static nextId; | ||
private tensors; | ||
private closed_; | ||
readonly idTensor: Tensor; | ||
constructor(name: string, dtype: DataType, maxSize: number, elementShape: number[], identicalElementShapes: boolean, dynamicSize: boolean, clearAfterRead: boolean); | ||
readonly id: number; | ||
constructor(name: string, dtype: DataType, maxSize: number, elementShape: number[], identicalElementShapes: boolean, dynamicSize: boolean, clearAfterRead: boolean); | ||
readonly closed: boolean; | ||
/** | ||
* Close the current TensorArray. | ||
* Dispose the tensors and idTensor and mark the TensoryArray as closed. | ||
*/ | ||
@@ -46,0 +46,0 @@ clearAndClose(): void; |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { concat, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core'; | ||
import { concat, keep, scalar, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core'; | ||
import { assertShapesMatchAllowUndefinedSize } from './tensor_utils'; | ||
@@ -35,4 +35,8 @@ /** | ||
this.closed_ = false; | ||
this.id = TensorArray.nextId++; | ||
this.idTensor = scalar(0); | ||
keep(this.idTensor); | ||
} | ||
get id() { | ||
return this.idTensor.id; | ||
} | ||
get closed() { | ||
@@ -42,3 +46,3 @@ return this.closed_; | ||
/** | ||
* Close the current TensorArray. | ||
* Dispose the tensors and idTensor and mark the TensoryArray as closed. | ||
*/ | ||
@@ -49,2 +53,3 @@ clearAndClose() { | ||
this.closed_ = true; | ||
this.idTensor.dispose(); | ||
} | ||
@@ -105,9 +110,10 @@ size() { | ||
assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${index}.`); | ||
if (t && t.read) { | ||
if (t.read) { | ||
throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been read.`); | ||
} | ||
if (t && t.written) { | ||
if (t.written) { | ||
throw new Error(`TensorArray ${this.name}: Could not write to TensorArray index ${index}, because it has already been written.`); | ||
} | ||
t.tensor = tensor; | ||
keep(tensor); | ||
t.written = true; | ||
@@ -237,3 +243,2 @@ this.tensors[index] = t; | ||
} | ||
TensorArray.nextId = 0; | ||
//# sourceMappingURL=tensor_array.js.map |
@@ -33,6 +33,8 @@ /** | ||
export declare class TensorList { | ||
tensors: Tensor[]; | ||
elementShape: number[]; | ||
elementDtype: DataType; | ||
readonly tensors: Tensor[]; | ||
readonly elementShape: number[]; | ||
readonly elementDtype: DataType; | ||
readonly idTensor: Tensor; | ||
maxNumElements: number; | ||
readonly id: number; | ||
/** | ||
@@ -52,2 +54,6 @@ * | ||
/** | ||
* Dispose the tensors and idTensor and clear the tensor list. | ||
*/ | ||
clearAndClose(): void; | ||
/** | ||
* The size of the tensors in the tensor list. | ||
@@ -113,3 +119,3 @@ */ | ||
*/ | ||
export declare function fromTensor(tensor: Tensor, elementShape: number[]): TensorList; | ||
export declare function fromTensor(tensor: Tensor, elementShape: number[], elementDtype: DataType): TensorList; | ||
/** | ||
@@ -129,3 +135,3 @@ * Return a TensorList of the given size with empty elements. | ||
*/ | ||
export declare function scatter(tensor: Tensor, indices: number[], elementShape: number[], numElements: number): TensorList; | ||
export declare function scatter(tensor: Tensor, indices: number[], elementShape: number[], numElements?: number): TensorList; | ||
/** | ||
@@ -132,0 +138,0 @@ * Split the values of a Tensor into a TensorList. |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { concat, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core'; | ||
import { concat, keep, scalar, slice, stack, tensor, tidy, unstack } from '@tensorflow/tfjs-core'; | ||
import { assertShapesMatchAllowUndefinedSize } from './tensor_utils'; | ||
@@ -47,4 +47,18 @@ /** | ||
this.elementDtype = elementDtype; | ||
if (tensors != null) { | ||
tensors.forEach(tensor => { | ||
if (elementDtype !== tensor.dtype) { | ||
throw new Error(`Invalid data types; op elements ${elementDtype}, but list elements ${tensor.dtype}`); | ||
} | ||
assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, 'TensorList shape mismatch: '); | ||
keep(tensor); | ||
}); | ||
} | ||
this.idTensor = scalar(0); | ||
this.maxNumElements = maxNumElements; | ||
keep(this.idTensor); | ||
} | ||
get id() { | ||
return this.idTensor.id; | ||
} | ||
/** | ||
@@ -57,2 +71,10 @@ * Get a new TensorList containing a copy of the underlying tensor container. | ||
/** | ||
* Dispose the tensors and idTensor and clear the tensor list. | ||
*/ | ||
clearAndClose() { | ||
this.tensors.forEach(tensor => tensor.dispose()); | ||
this.tensors.length = 0; | ||
this.idTensor.dispose(); | ||
} | ||
/** | ||
* The size of the tensors in the tensor list. | ||
@@ -111,2 +133,3 @@ */ | ||
} | ||
keep(tensor); | ||
this.tensors.push(tensor); | ||
@@ -160,2 +183,3 @@ } | ||
assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: '); | ||
keep(tensor); | ||
this.tensors[elementIndex] = tensor; | ||
@@ -210,3 +234,3 @@ } | ||
*/ | ||
export function fromTensor(tensor, elementShape) { | ||
export function fromTensor(tensor, elementShape, elementDtype) { | ||
const dtype = tensor.dtype; | ||
@@ -216,9 +240,8 @@ if (tensor.shape.length < 1) { | ||
} | ||
if (tensor.dtype !== elementDtype) { | ||
throw new Error(`Invalid data types; op elements ${tensor.dtype}, but list elements ${elementDtype}`); | ||
} | ||
const outputShape = tensor.shape.slice(1); | ||
assertShapesMatchAllowUndefinedSize(outputShape, elementShape, 'TensorList shape mismatch: '); | ||
const tensorList = []; | ||
for (let i = 0; i < tensor.shape[0]; ++i) { | ||
const tmp = tensor.slice(i, i + 1).reshape(outputShape); | ||
tensorList.push(tmp); | ||
} | ||
const tensorList = tensor.unstack(); | ||
return new TensorList(tensorList, elementShape, dtype); | ||
@@ -247,3 +270,3 @@ } | ||
const maxIndex = Math.max(...indices); | ||
if (numElements !== -1 && maxIndex >= numElements) { | ||
if (numElements != null && numElements !== -1 && maxIndex >= numElements) { | ||
throw new Error(`Max index must be < array size (${maxIndex} vs. ${numElements})`); | ||
@@ -286,2 +309,3 @@ } | ||
} | ||
tensor.dispose(); | ||
return tensors; | ||
@@ -288,0 +312,0 @@ }); |
@@ -18,6 +18,9 @@ /** | ||
import { Tensor } from '@tensorflow/tfjs-core'; | ||
import { NamedTensorsMap } from '../data/types'; | ||
import { NamedTensorsMap, TensorArrayMap, TensorListMap } from '../data/types'; | ||
/** | ||
* | ||
*/ | ||
export interface FunctionExecutor { | ||
executeFunctionAsync(inputs: Tensor[]): Promise<Tensor[]>; | ||
executeFunctionAsync(inputs: Tensor[], tensorArrayMap: TensorArrayMap, tensorListMap: TensorListMap): Promise<Tensor[]>; | ||
weightMap: NamedTensorsMap; | ||
} |
@@ -19,2 +19,3 @@ /** | ||
import { TensorArray } from '../../executor/tensor_array'; | ||
import { fromTensor, reserve, scatter, split } from '../../executor/tensor_list'; | ||
import { getParamValue, getTensor } from './utils'; | ||
@@ -31,6 +32,6 @@ export const executeOp = async (node, tensorMap, context) => { | ||
if (condValue[0]) { | ||
return context.functionMap[thenFunc].executeFunctionAsync(args); | ||
return context.functionMap[thenFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap); | ||
} | ||
else { | ||
return context.functionMap[elseFunc].executeFunctionAsync(args); | ||
return context.functionMap[elseFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap); | ||
} | ||
@@ -43,17 +44,45 @@ } | ||
const args = getParamValue('args', node, tensorMap, context); | ||
const condTensor = (await context.functionMap[condFunc].executeFunctionAsync(args))[0]; | ||
let condValue = await condTensor.data(); | ||
// Calculate the condition of the loop | ||
const condResult = (await context.functionMap[condFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap)); | ||
const argIds = args.map(tensor => tensor.id); | ||
let condValue = await condResult[0].data(); | ||
// Dispose the intermediate tensors for condition function | ||
condResult.forEach(tensor => { | ||
if (!tensor.kept && argIds.indexOf(tensor.id) === -1) { | ||
tensor.dispose(); | ||
} | ||
}); | ||
let result = args; | ||
while (condValue[0]) { | ||
result = | ||
await context.functionMap[bodyFunc].executeFunctionAsync(result); | ||
const condTensor = (await context.functionMap[condFunc].executeFunctionAsync(result))[0]; | ||
condValue = await condTensor.data(); | ||
// Record the previous result for intermediate tensor tracking | ||
const origResult = result; | ||
// Execution the body of the loop | ||
result = await context.functionMap[bodyFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap); | ||
const resultIds = result.map(tensor => tensor.id); | ||
// Dispose the intermediate tensor for body function that is not global | ||
// kept, not input/output of the body function | ||
origResult.forEach(tensor => { | ||
if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && | ||
resultIds.indexOf(tensor.id) === -1) { | ||
tensor.dispose(); | ||
} | ||
}); | ||
// Recalcuate the condition of the loop using the latest results. | ||
const condResult = (await context.functionMap[condFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap)); | ||
condValue = await condResult[0].data(); | ||
// Dispose the intermediate tensors for condition function | ||
condResult.forEach(tensor => { | ||
if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && | ||
resultIds.indexOf(tensor.id) === -1) { | ||
tensor.dispose(); | ||
} | ||
}); | ||
} | ||
return result; | ||
} | ||
case 'LoopCond': | ||
case 'LoopCond': { | ||
return [ | ||
getParamValue('pred', node, tensorMap, context).clone() | ||
]; | ||
} | ||
case 'Switch': { | ||
@@ -66,7 +95,8 @@ const pred = getParamValue('pred', node, tensorMap, context); | ||
} | ||
case 'Merge': | ||
case 'Merge': { | ||
const inputName = node.inputNames.find(name => getTensor(name, tensorMap, context) !== undefined); | ||
return inputName ? [getTensor(inputName, tensorMap, context).clone()] : | ||
undefined; | ||
case 'Enter': | ||
} | ||
case 'Enter': { | ||
const frameId = getParamValue('frameName', node, tensorMap, context); | ||
@@ -76,11 +106,14 @@ const data = getParamValue('tensor', node, tensorMap, context); | ||
return [data.clone()]; | ||
case 'Exit': | ||
} | ||
case 'Exit': { | ||
const tensor = getParamValue('tensor', node, tensorMap, context); | ||
context.exitFrame(); | ||
return [tensor.clone()]; | ||
case 'NextIteration': | ||
} | ||
case 'NextIteration': { | ||
const input = getParamValue('tensor', node, tensorMap, context); | ||
context.nextIteration(); | ||
return [input.clone()]; | ||
case 'TensorArrayV3': | ||
} | ||
case 'TensorArrayV3': { | ||
const size = getParamValue('size', node, tensorMap, context); | ||
@@ -95,49 +128,145 @@ const dtype = getParamValue('dtype', node, tensorMap, context); | ||
context.addTensorArray(tensorArray); | ||
return [scalar(tensorArray.id), scalar(1.0)]; | ||
case 'TensorArrayWriteV3': | ||
return [tensorArray.idTensor, scalar(1.0)]; | ||
} | ||
case 'TensorArrayWriteV3': { | ||
const id = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const index = getParamValue('index', node, tensorMap, context); | ||
const writeTensor = getParamValue('tensor', node, tensorMap, context); | ||
const writeTensorArray = context.getTensorArray(id); | ||
const writeTensorArray = context.getTensorArray(id.id); | ||
writeTensorArray.write(index, writeTensor); | ||
return [scalar(1.0)]; | ||
case 'TensorArrayReadV3': | ||
return [writeTensorArray.idTensor]; | ||
} | ||
case 'TensorArrayReadV3': { | ||
const readId = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const readIndex = getParamValue('index', node, tensorMap, context); | ||
const readTensorArray = context.getTensorArray(readId); | ||
const readTensorArray = context.getTensorArray(readId.id); | ||
return [readTensorArray.read(readIndex)]; | ||
case 'TensorArrayGatherV3': | ||
} | ||
case 'TensorArrayGatherV3': { | ||
const gatherId = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const gatherIndices = getParamValue('indices', node, tensorMap, context); | ||
const gatherDtype = getParamValue('dtype', node, tensorMap, context); | ||
const gatherTensorArray = context.getTensorArray(gatherId); | ||
const gatherTensorArray = context.getTensorArray(gatherId.id); | ||
return [gatherTensorArray.gather(gatherIndices, gatherDtype)]; | ||
case 'TensorArrayScatterV3': | ||
} | ||
case 'TensorArrayScatterV3': { | ||
const scatterId = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const scatterIndices = getParamValue('indices', node, tensorMap, context); | ||
const scatterTensor = getParamValue('tensor', node, tensorMap, context); | ||
const scatterTensorArray = context.getTensorArray(scatterId); | ||
const scatterTensorArray = context.getTensorArray(scatterId.id); | ||
scatterTensorArray.scatter(scatterIndices, scatterTensor); | ||
return [scalar(1.0)]; | ||
case 'TensorArrayConcatV3': | ||
return [scatterTensorArray.idTensor]; | ||
} | ||
case 'TensorArrayConcatV3': { | ||
const concatId = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const concatTensorArray = context.getTensorArray(concatId); | ||
const concatTensorArray = context.getTensorArray(concatId.id); | ||
const concatDtype = getParamValue('dtype', node, tensorMap, context); | ||
return [concatTensorArray.concat(concatDtype)]; | ||
case 'TensorArraySplitV3': | ||
} | ||
case 'TensorArraySplitV3': { | ||
const splitId = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const splitTensor = getParamValue('tensor', node, tensorMap, context); | ||
const lengths = getParamValue('lengths', node, tensorMap, context); | ||
const splitTensorArray = context.getTensorArray(splitId); | ||
const splitTensorArray = context.getTensorArray(splitId.id); | ||
splitTensorArray.split(lengths, splitTensor); | ||
return [scalar(1.0)]; | ||
case 'TensorArraySizeV3': | ||
return [splitTensorArray.idTensor]; | ||
} | ||
case 'TensorArraySizeV3': { | ||
const sizeId = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const sizeTensorArray = context.getTensorArray(sizeId); | ||
const sizeTensorArray = context.getTensorArray(sizeId.id); | ||
return [scalar(sizeTensorArray.size(), 'int32')]; | ||
case 'TensorArrayCloseV3': | ||
} | ||
case 'TensorArrayCloseV3': { | ||
const closeId = getParamValue('tensorArrayId', node, tensorMap, context); | ||
const closeTensorArray = context.getTensorArray(closeId); | ||
const closeTensorArray = context.getTensorArray(closeId.id); | ||
closeTensorArray.clearAndClose(); | ||
return [scalar(0)]; | ||
return [closeTensorArray.idTensor]; | ||
} | ||
case 'TensorListSetItem': { | ||
const idTensor = getParamValue('tensorListId', node, tensorMap, context); | ||
const index = getParamValue('index', node, tensorMap, context); | ||
const writeTensor = getParamValue('tensor', node, tensorMap, context); | ||
const tensorList = context.getTensorList(idTensor.id); | ||
tensorList.setItem(index, writeTensor); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListGetItem': { | ||
const idTensor = getParamValue('tensorListId', node, tensorMap, context); | ||
const readIndex = getParamValue('index', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const elementDType = getParamValue('elementDType', node, tensorMap, context); | ||
const tensorList = context.getTensorList(idTensor.id); | ||
return [tensorList.getItem(readIndex, elementShape, elementDType)]; | ||
} | ||
case 'TensorListScatterV2': | ||
case 'TensorListScatter': { | ||
const scatterIndices = getParamValue('indices', node, tensorMap, context); | ||
const scatterTensor = getParamValue('tensor', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const numElements = getParamValue('numElements', node, tensorMap, context); | ||
const tensorList = scatter(scatterTensor, scatterIndices, elementShape, numElements); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListReserve': { | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const elementDtype = getParamValue('elementDType', node, tensorMap, context); | ||
const numElements = getParamValue('numElements', node, tensorMap, context); | ||
const tensorList = reserve(elementShape, elementDtype, numElements); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListGather': { | ||
const gatherId = getParamValue('tensorListId', node, tensorMap, context); | ||
const gatherIndices = getParamValue('indices', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const elementDtype = getParamValue('elementDType', node, tensorMap, context); | ||
const tensorList = context.getTensorList(gatherId.id); | ||
return [tensorList.gather(gatherIndices, elementDtype, elementShape)]; | ||
} | ||
case 'TensorListStack': { | ||
const idTensor = getParamValue('tensorListId', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const elementDtype = getParamValue('elementDType', node, tensorMap, context); | ||
const numElements = getParamValue('numElements', node, tensorMap, context); | ||
const tensorList = context.getTensorList(idTensor.id); | ||
return [tensorList.stack(elementShape, elementDtype, numElements)]; | ||
} | ||
case 'TensorListFromTensor': { | ||
const tensor = getParamValue('tensor', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const elementDtype = getParamValue('elementDType', node, tensorMap, context); | ||
const tensorList = fromTensor(tensor, elementShape, elementDtype); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListConcat': { | ||
const concatId = getParamValue('tensorListId', node, tensorMap, context); | ||
const tensorList = context.getTensorList(concatId.id); | ||
const concatDtype = getParamValue('dtype', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
return [tensorList.concat(concatDtype, elementShape)]; | ||
} | ||
case 'TensorListPushBack': { | ||
const idTensor = getParamValue('tensorListId', node, tensorMap, context); | ||
const writeTensor = getParamValue('tensor', node, tensorMap, context); | ||
const tensorList = context.getTensorList(idTensor.id); | ||
tensorList.pushBack(writeTensor); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListPopBack': { | ||
const idTensor = getParamValue('tensorListId', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const elementDType = getParamValue('elementDType', node, tensorMap, context); | ||
const tensorList = context.getTensorList(idTensor.id); | ||
return [tensorList.popBack(elementShape, elementDType)]; | ||
} | ||
case 'TensorListSplit': { | ||
const splitTensor = getParamValue('tensor', node, tensorMap, context); | ||
const elementShape = getParamValue('elementShape', node, tensorMap, context); | ||
const lengths = getParamValue('lengths', node, tensorMap, context); | ||
const tensorList = split(splitTensor, lengths, elementShape); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
default: | ||
@@ -144,0 +273,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
@@ -18,3 +18,3 @@ /** | ||
import * as tfc from '@tensorflow/tfjs-core'; | ||
import { getParamValue } from './utils'; | ||
import { getPadding, getParamValue } from './utils'; | ||
export const executeOp = (node, tensorMap, context) => { | ||
@@ -32,3 +32,3 @@ switch (node.op) { | ||
const stride = getParamValue('strides', node, tensorMap, context); | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
const dataFormat = getParamValue('dataFormat', node, tensorMap, context) | ||
@@ -60,3 +60,3 @@ .toUpperCase(); | ||
const stride = getParamValue('strides', node, tensorMap, context); | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
const dataFormat = getParamValue('dataFormat', node, tensorMap, context) | ||
@@ -85,3 +85,3 @@ .toUpperCase(); | ||
const stride = getParamValue('strides', node, tensorMap, context); | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
return [tfc.conv2dTranspose(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), shape, [stride[1], stride[2]], pad)]; | ||
@@ -92,3 +92,3 @@ } | ||
const stride = getParamValue('strides', node, tensorMap, context); | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
const dilations = getParamValue('dilations', node, tensorMap, context); | ||
@@ -139,2 +139,14 @@ const dataFormat = getParamValue('dataFormat', node, tensorMap, context) | ||
} | ||
case 'Dilation2D': { | ||
const strides = getParamValue('strides', node, tensorMap, context); | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const dilations = getParamValue('dilations', node, tensorMap, context); | ||
// strides: [1, stride_height, stride_width, 1]. | ||
const strideHeight = strides[1]; | ||
const strideWidth = strides[2]; | ||
// dilations: [1, dilation_height, dilation_width, 1]. | ||
const dilationHeight = dilations[1]; | ||
const dilationWidth = dilations[2]; | ||
return [tfc.dilation2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [strideHeight, strideWidth], pad, [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)]; | ||
} | ||
default: | ||
@@ -141,0 +153,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -5,0 +5,0 @@ * you may not use this file except in compliance with the License. |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -22,2 +22,3 @@ * you may not use this file except in compliance with the License. | ||
case 'NonMaxSuppressionV5': | ||
case 'NonMaxSuppressionV4': | ||
case 'NonMaxSuppressionV3': | ||
@@ -35,2 +36,7 @@ case 'NonMaxSuppressionV2': { | ||
} | ||
if (node.op === 'NonMaxSuppressionV4') { | ||
const padToMaxOutputSize = getParamValue('padToMaxOutputSize', node, tensorMap, context); | ||
const result = await tfc.image.nonMaxSuppressionPaddedAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize); | ||
return [result.selectedIndices, result.validOutputs]; | ||
} | ||
return [await tfc.image.nonMaxSuppressionAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)]; | ||
@@ -37,0 +43,0 @@ } |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -5,0 +5,0 @@ * you may not use this file except in compliance with the License. |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -5,0 +5,0 @@ * you may not use this file except in compliance with the License. |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -5,0 +5,0 @@ * you may not use this file except in compliance with the License. |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -5,0 +5,0 @@ * you may not use this file except in compliance with the License. |
@@ -59,9 +59,2 @@ /** | ||
const tensor = getParamValue('x', node, tensorMap, context); | ||
if (begin.length === 1 && tensor.shape.length > 1) { | ||
for (let i = 1; i < tensor.shape.length; i++) { | ||
begin.push(0); | ||
end.push(tensor.shape[i]); | ||
strides.push(strides[0]); | ||
} | ||
} | ||
return [tfc.stridedSlice(tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)]; | ||
@@ -102,3 +95,4 @@ } | ||
const numOrSizeSplits = getParamValue('numOrSizeSplits', node, tensorMap, context); | ||
return tfc.split(getParamValue('x', node, tensorMap, context), numOrSizeSplits, axis); | ||
const tensor = getParamValue('x', node, tensorMap, context); | ||
return tfc.split(tensor, numOrSizeSplits, axis); | ||
} | ||
@@ -105,0 +99,0 @@ case 'ScatterNd': { |
@@ -25,2 +25,3 @@ /** | ||
export declare function createTensorShapeAttr(value: number[]): ParamValue; | ||
export declare function createShapeAttrFromIndex(inputIndex: number): InputParamValue; | ||
export declare function createNumericArrayAttr(value: number[]): ParamValue; | ||
@@ -27,0 +28,0 @@ export declare function createNumericArrayAttrFromIndex(inputIndex: number): InputParamValue; |
@@ -19,2 +19,5 @@ export function createNumberAttr(value) { | ||
} | ||
export function createShapeAttrFromIndex(inputIndex) { | ||
return { inputIndexStart: inputIndex, type: 'shape' }; | ||
} | ||
export function createNumericArrayAttr(value) { | ||
@@ -39,3 +42,3 @@ return { value, type: 'number[]' }; | ||
opMappers.find(mapper => mapper.tfOpName === node.op); | ||
return Object.keys(node.inputParams).every(key => { | ||
const matched = Object.keys(node.inputParams).every(key => { | ||
const value = node.inputParams[key]; | ||
@@ -51,3 +54,8 @@ const def = opMapper.inputs.find(param => param.name === key); | ||
}); | ||
if (!matched) { | ||
console.log('node = ', node); | ||
console.log('opMapper = ', opMapper); | ||
} | ||
return matched; | ||
} | ||
//# sourceMappingURL=test_helper.js.map |
@@ -18,3 +18,3 @@ /** | ||
import * as tfc from '@tensorflow/tfjs-core'; | ||
import { getParamValue, split } from './utils'; | ||
import { getParamValue } from './utils'; | ||
export const executeOp = (node, tensorMap, context) => { | ||
@@ -38,7 +38,7 @@ switch (node.op) { | ||
case 'Pad': { | ||
return [tfc.pad(getParamValue('x', node, tensorMap, context), split(getParamValue('padding', node, tensorMap, context), 2), getParamValue('constantValue', node, tensorMap, context))]; | ||
return [tfc.pad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('constantValue', node, tensorMap, context))]; | ||
} | ||
case 'SpaceToBatchND': { | ||
const blockShape = getParamValue('blockShape', node, tensorMap, context); | ||
const paddings = split(getParamValue('paddings', node, tensorMap, context), 2); | ||
const paddings = getParamValue('paddings', node, tensorMap, context); | ||
return [tfc.spaceToBatchND(getParamValue('x', node, tensorMap, context), blockShape, paddings)]; | ||
@@ -48,3 +48,3 @@ } | ||
const blockShape = getParamValue('blockShape', node, tensorMap, context); | ||
const crops = split(getParamValue('crops', node, tensorMap, context), 2); | ||
const crops = getParamValue('crops', node, tensorMap, context); | ||
return [tfc.batchToSpaceND(getParamValue('x', node, tensorMap, context), blockShape, crops)]; | ||
@@ -51,0 +51,0 @@ } |
@@ -44,1 +44,2 @@ /** | ||
export declare function split(arr: number[], size: number): number[][]; | ||
export declare function getPadding(node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext): ValueType; |
@@ -17,2 +17,3 @@ /** | ||
*/ | ||
import * as tfc from '@tensorflow/tfjs-core'; | ||
export function getParamValue(paramName, node, tensorMap, context) { | ||
@@ -33,5 +34,7 @@ const inputParam = node.inputParams[paramName]; | ||
} | ||
const data = Array.prototype.slice.call(getTensor(node.inputNames.slice(start)[0], tensorMap, context) | ||
.dataSync()); | ||
return inputParam.type === 'number' ? data[0] : data; | ||
const tensor = getTensor(node.inputNames.slice(start)[0], tensorMap, context); | ||
const data = tensor.dataSync(); | ||
return inputParam.type === 'number' ? | ||
data[0] : | ||
tfc.util.toNestedArray(tensor.shape, data); | ||
} | ||
@@ -95,2 +98,16 @@ const attrParam = node.attrParams[paramName]; | ||
} | ||
export function getPadding(node, tensorMap, context) { | ||
let pad = getParamValue('pad', node, tensorMap, context); | ||
if (pad === 'explicit') { | ||
// This is 1d array, we need to convert it to 2d array | ||
pad = getParamValue('explicitPaddings', node, tensorMap, context); | ||
const explicitPadding = [[0, 0], [0, 0], [0, 0], [0, 0]]; | ||
for (let i = 0; i < 4; i++) { | ||
explicitPadding[i][0] = pad[i * 2]; | ||
explicitPadding[i][1] = pad[i * 2 + 1]; | ||
} | ||
return explicitPadding; | ||
} | ||
return pad; | ||
} | ||
//# sourceMappingURL=utils.js.map |
@@ -91,3 +91,3 @@ /** | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }, | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'index', 'type': 'number' }, | ||
@@ -105,3 +105,3 @@ { 'start': 2, 'name': 'tensor', 'type': 'tensor' }, | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }, | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'index', 'type': 'number' }, | ||
@@ -121,3 +121,3 @@ { 'start': 2, 'name': 'flowIn', 'type': 'number' }, | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }, | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'indices', 'type': 'number[]' }, | ||
@@ -135,3 +135,3 @@ { 'start': 2, 'name': 'flowIn', 'type': 'number' }, | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }, | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'indices', 'type': 'number[]' }, | ||
@@ -147,3 +147,3 @@ { 'start': 2, 'name': 'tensor', 'type': 'tensor' }, | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }, | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'flowIn', 'type': 'number' }, | ||
@@ -164,3 +164,3 @@ ], | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }, | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'tensor', 'type': 'tensor' }, | ||
@@ -176,3 +176,3 @@ { 'start': 2, 'name': 'lengths', 'type': 'number[]' }, | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }, | ||
{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'flowIn', 'type': 'number' } | ||
@@ -184,3 +184,3 @@ ] | ||
'category': 'control', | ||
'inputs': [{ 'start': 0, 'name': 'tensorArrayId', 'type': 'number' }] | ||
'inputs': [{ 'start': 0, 'name': 'tensorArrayId', 'type': 'tensor' }] | ||
}, | ||
@@ -232,4 +232,124 @@ { | ||
] | ||
} | ||
}, | ||
{ | ||
'tfOpName': 'TensorListScatter', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensor', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'indices', 'type': 'number[]' }, | ||
{ 'start': 2, 'name': 'elementShape', 'type': 'shape' } | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListScatterV2', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensor', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'indices', 'type': 'number[]' }, | ||
{ 'start': 2, 'name': 'elementShape', 'type': 'shape' }, | ||
{ 'start': 3, 'name': 'numElements', 'type': 'number' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListGather', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorListId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'indices', 'type': 'number[]' }, | ||
{ 'start': 2, 'name': 'elementShape', 'type': 'shape' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListGetItem', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorListId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'index', 'type': 'number' }, | ||
{ 'start': 2, 'name': 'elementShape', 'type': 'shape' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListSetItem', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorListId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'index', 'type': 'number' }, | ||
{ 'start': 2, 'name': 'tensor', 'type': 'tensor' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListReserve', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'elementShape', 'type': 'shape' }, | ||
{ 'start': 1, 'name': 'numElements', 'type': 'number' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListFromTensor', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensor', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'elementShape', 'type': 'shape' } | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListStack', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorListId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'elementShape', 'type': 'shape' }, | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }, | ||
{ 'tfName': 'num_elements', 'name': 'numElements', 'type': 'dtype' } | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListSplit', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensor', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'elementShape', 'type': 'shape' }, | ||
{ 'start': 2, 'name': 'lengths', 'type': 'number[]' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListConcat', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorListId', 'type': 'tensor' }, | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape' }, | ||
{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' } | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListPopBack', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorListId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'elementShape', 'type': 'shape' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListPushBack', | ||
'category': 'control', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'tensorListId', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'tensor', 'type': 'tensor' }, | ||
], | ||
'attrs': [{ 'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype' }] | ||
}, | ||
]; | ||
//# sourceMappingURL=control.js.map |
@@ -1,5 +0,4 @@ | ||
import { OpMapper } from '../types'; | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Copyright 2020 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -18,2 +17,3 @@ * you may not use this file except in compliance with the License. | ||
*/ | ||
import { OpMapper } from '../types'; | ||
export declare const json: OpMapper[]; |
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Copyright 2020 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -329,4 +329,17 @@ * you may not use this file except in compliance with the License. | ||
], | ||
}, | ||
{ | ||
'tfOpName': 'Dilation2D', | ||
'category': 'convolution', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'x', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'filter', 'type': 'tensor' }, | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'strides', 'name': 'strides', 'type': 'number[]' }, | ||
{ 'tfName': 'rates', 'name': 'dilations', 'type': 'number[]' }, | ||
{ 'tfName': 'padding', 'name': 'pad', 'type': 'string' } | ||
] | ||
} | ||
]; | ||
//# sourceMappingURL=convolution.js.map |
@@ -40,2 +40,26 @@ /** | ||
{ | ||
'tfOpName': 'NonMaxSuppressionV4', | ||
'category': 'dynamic', | ||
'inputs': [ | ||
{ 'start': 0, 'name': 'boxes', 'type': 'tensor' }, | ||
{ 'start': 1, 'name': 'scores', 'type': 'tensor' }, | ||
{ 'start': 2, 'name': 'maxOutputSize', 'type': 'number' }, | ||
{ 'start': 3, 'name': 'iouThreshold', 'type': 'number' }, | ||
{ 'start': 4, 'name': 'scoreThreshold', 'type': 'number' } | ||
], | ||
'attrs': [ | ||
{ 'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true }, { | ||
'tfName': 'T_threshold', | ||
'name': 'threshold', | ||
'type': 'dtype', | ||
'notSupported': true | ||
}, | ||
{ | ||
'tfName': 'pad_to_max_output_size', | ||
'name': 'padToMaxOutputSize', | ||
'type': 'bool' | ||
} | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'NonMaxSuppressionV5', | ||
@@ -42,0 +66,0 @@ 'category': 'dynamic', |
/** @license See the LICENSE file. */ | ||
declare const version = "2.0.1"; | ||
declare const version = "2.1.0"; | ||
export { version }; |
/** @license See the LICENSE file. */ | ||
// This code is auto-generated, do not modify this file! | ||
const version = '2.0.1'; | ||
const version = '2.1.0'; | ||
export { version }; | ||
//# sourceMappingURL=version.js.map |
{ | ||
"name": "@tensorflow/tfjs-converter", | ||
"version": "2.0.1", | ||
"version": "2.1.0", | ||
"description": "Tensorflow model converter for javascript", | ||
@@ -18,3 +18,3 @@ "main": "dist/tf-converter.node.js", | ||
"peerDependencies": { | ||
"@tensorflow/tfjs-core": "2.0.1" | ||
"@tensorflow/tfjs-core": "2.1.0" | ||
}, | ||
@@ -25,4 +25,4 @@ "devDependencies": { | ||
"@rollup/plugin-typescript": "^3.0.0", | ||
"@tensorflow/tfjs-backend-cpu": "2.0.1", | ||
"@tensorflow/tfjs-core": "2.0.1", | ||
"@tensorflow/tfjs-backend-cpu": "2.1.0", | ||
"@tensorflow/tfjs-core": "2.1.0", | ||
"@types/deep-equal": "^1.0.1", | ||
@@ -81,2 +81,3 @@ "@types/jasmine": "~2.8.6", | ||
"gen-json": "ts-node -s ./scripts/gen_json.ts", | ||
"model-summary": "ts-node -s ./tools/model_summary.ts", | ||
"pb2json": "ts-node -s ./tools/pb2json_converter.ts", | ||
@@ -83,0 +84,0 @@ "build-pip-package": "yarn gen-json --test && cd python && ./build-pip-package.sh --test /tmp/tfjs-pips", |
@@ -10,3 +10,3 @@ # Getting started | ||
__Note__: _Session bundle and Frozen model formats have been deprecated in TensorFlow.js 1.0. Please use the TensorFlow.js 0.15.x backend to convert these formats, available in | ||
__Note__: _Session bundle format have been deprecated in TensorFlow.js 1.0. Please use the TensorFlow.js 0.15.x backend to convert session bundle, available in | ||
`tfjs-converter` [0.8.6](https://pypi.org/project/tensorflowjs/0.8.6/)._ | ||
@@ -13,0 +13,0 @@ |
/** | ||
* @license | ||
* Copyright 2020 Google Inc. All Rights Reserved. | ||
* Copyright 2020 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -5,0 +5,0 @@ * you may not use this file except in compliance with the License. |
@@ -19,2 +19,3 @@ /** | ||
import {TensorArray} from '../executor/tensor_array'; | ||
import {TensorList} from '../executor/tensor_list'; | ||
@@ -33,2 +34,6 @@ export type NamedTensorMap = { | ||
export type TensorListMap = { | ||
[key: number]: TensorList | ||
}; | ||
export interface TensorInfo { | ||
@@ -35,0 +40,0 @@ name: string; |
@@ -19,5 +19,6 @@ /** | ||
import {NamedTensorsMap, TensorArrayMap} from '../data/types'; | ||
import {NamedTensorsMap, TensorArrayMap, TensorListMap} from '../data/types'; | ||
import {TensorArray} from './tensor_array'; | ||
import {TensorList} from './tensor_list'; | ||
import {FunctionExecutor} from './types'; | ||
@@ -48,5 +49,6 @@ | ||
constructor( | ||
public readonly weightMap: NamedTensorsMap, | ||
public readonly tensorArrayMap: TensorArrayMap, | ||
public readonly functionMap: {[key: string]: FunctionExecutor} = {}) { | ||
readonly weightMap: NamedTensorsMap = {}, | ||
readonly tensorArrayMap: TensorArrayMap = {}, | ||
readonly tensorListMap: TensorListMap = {}, | ||
readonly functionMap: {[key: string]: FunctionExecutor} = {}) { | ||
this.generateCurrentContextIds(); | ||
@@ -169,2 +171,20 @@ } | ||
} | ||
addTensorList(tensorList: TensorList) { | ||
this.tensorListMap[tensorList.id] = tensorList; | ||
} | ||
getTensorList(id: number): TensorList { | ||
return this.tensorListMap[id]; | ||
} | ||
dispose() { | ||
for (const key in this.tensorArrayMap) { | ||
this.tensorArrayMap[key].clearAndClose(); | ||
} | ||
for (const key in this.tensorListMap) { | ||
this.tensorListMap[key].clearAndClose(); | ||
} | ||
} | ||
} |
@@ -21,3 +21,3 @@ /** | ||
import {ISignatureDef} from '../data/compiled_api'; | ||
import {NamedTensorsMap, TensorArrayMap, TensorInfo} from '../data/types'; | ||
import {NamedTensorsMap, TensorArrayMap, TensorInfo, TensorListMap} from '../data/types'; | ||
import {getNodeNameAndIndex, getParamValue, getTensor, getTensorsForCurrentContenxt, parseNodeName} from '../operations/executors/utils'; | ||
@@ -197,5 +197,7 @@ import {executeOp} from '../operations/operation_executor'; | ||
const tensorArrayMap: TensorArrayMap = {}; | ||
const tensorListMap: TensorListMap = {}; | ||
return tidy(() => { | ||
const context = new ExecutionContext( | ||
this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
this.weightMap, tensorArrayMap, tensorListMap, | ||
this.functionExecutorMap); | ||
const tensorsMap: NamedTensorsMap = {...this.weightMap}; | ||
@@ -225,2 +227,6 @@ Object.keys(inputs).forEach(name => { | ||
} | ||
// dispose the context for the root executor | ||
if (this.parent == null) { | ||
context.dispose(); | ||
} | ||
return outputs.map(name => getTensor(name, tensorsMap, context)); | ||
@@ -280,2 +286,3 @@ }); | ||
} | ||
/** | ||
@@ -289,16 +296,38 @@ * Executes the inference for given input tensors in Async fashion. | ||
* array. | ||
* @param disableWarning disable the no dynamic ops warning message, default | ||
* to false | ||
*/ | ||
async executeAsync( | ||
inputs: NamedTensorMap, outputs: string[], | ||
disableWarning = false): Promise<Tensor[]> { | ||
inputs = this.mapInputs(inputs); | ||
this.checkInputs(inputs); | ||
this.checkInputShapeAndType(inputs); | ||
outputs = this.mapOutputs(outputs); | ||
this.checkOutputs(outputs); | ||
const tensorArrayMap: TensorArrayMap = {}; | ||
async executeAsync(inputs: NamedTensorMap, outputs: string[]): | ||
Promise<Tensor[]> { | ||
return this._executeAsync(inputs, outputs); | ||
} | ||
/** | ||
* Executes the inference for given input tensors in Async fashion. | ||
* @param inputs Tensor map for the model inputs, keyed by the input node | ||
* names. | ||
* @param outputs output node name from the Tensorflow model, if no outputs | ||
* are specified, the default outputs of the model would be used. You can | ||
* inspect intermediate nodes of the model by adding them to the outputs | ||
* array. | ||
* @param isFunctionExecution Flag for executing a function. | ||
* @param tensorArrayMap Optional, global TensorArray map by id. Used for | ||
* function execution. | ||
* @param tensorArrayMap Optinal global TensorList map by id. Used for | ||
* function execution. | ||
*/ | ||
private async _executeAsync( | ||
inputs: NamedTensorMap, outputs: string[], isFunctionExecution = false, | ||
tensorArrayMap: TensorArrayMap = {}, | ||
tensorListMap: TensorListMap = {}): Promise<Tensor[]> { | ||
if (!isFunctionExecution) { | ||
inputs = this.mapInputs(inputs); | ||
this.checkInputs(inputs); | ||
this.checkInputShapeAndType(inputs); | ||
outputs = this.mapOutputs(outputs); | ||
this.checkOutputs(outputs); | ||
} | ||
const context = new ExecutionContext( | ||
this.weightMap, tensorArrayMap, this.functionExecutorMap); | ||
this.weightMap, tensorArrayMap, tensorListMap, | ||
this.functionExecutorMap); | ||
// Graph with control flow op requires runtime evaluation of the execution | ||
@@ -308,3 +337,3 @@ // order, while without control flow the execution order is pre-determined | ||
const tensorMap = await this.executeWithControlFlow( | ||
inputs, context, outputs, disableWarning); | ||
inputs, context, outputs, isFunctionExecution); | ||
const results = outputs.map(name => getTensor(name, tensorMap, context)); | ||
@@ -326,6 +355,13 @@ | ||
}); | ||
// dispose the context for the root executor | ||
if (this.parent == null) { | ||
context.dispose(); | ||
} | ||
return results; | ||
} | ||
async executeFunctionAsync(inputs: Tensor[]): Promise<Tensor[]> { | ||
async executeFunctionAsync( | ||
inputs: Tensor[], tensorArrayMap: TensorArrayMap, | ||
tensorListMap: TensorListMap): Promise<Tensor[]> { | ||
const mappedInputs = inputs.reduce((map, tensor, index) => { | ||
@@ -336,3 +372,4 @@ map[this.inputs[index].name] = tensor; | ||
return this.executeAsync(mappedInputs, this.outputNodes, true); | ||
return this._executeAsync( | ||
mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap); | ||
} | ||
@@ -344,7 +381,7 @@ /** | ||
* @param context the execution context object for current execution. | ||
* @param disableWarning disable no async op warning | ||
* @param isFunctionExecution Flag for executing a function. | ||
*/ | ||
private async executeWithControlFlow( | ||
inputs: NamedTensorMap, context: ExecutionContext, outputNames: string[], | ||
disableWarning: boolean): Promise<NamedTensorsMap> { | ||
isFunctionExecution: boolean): Promise<NamedTensorsMap> { | ||
const names = Object.keys(inputs); | ||
@@ -378,3 +415,3 @@ const inputNodes = | ||
} | ||
if (dynamicNode == null && !disableWarning) { | ||
if (dynamicNode == null && !isFunctionExecution) { | ||
console.warn( | ||
@@ -381,0 +418,0 @@ `This model execution did not contain any nodes with control flow ` + |
@@ -133,3 +133,3 @@ /** | ||
'Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf', | ||
'StatelessWhile' | ||
'StatelessWhile', 'if', 'While' | ||
]; | ||
@@ -136,0 +136,0 @@ const DYNAMIC_SHAPE_OPS = [ |
@@ -18,3 +18,4 @@ /** | ||
import {concat, DataType, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; | ||
import {concat, DataType, keep, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; | ||
import {assertShapesMatchAllowUndefinedSize} from './tensor_utils'; | ||
@@ -33,15 +34,17 @@ | ||
export class TensorArray { | ||
private static nextId = 0; | ||
private tensors: TensorWithState[] = []; | ||
private closed_ = false; | ||
readonly id: number; | ||
readonly idTensor: Tensor; | ||
constructor( | ||
public readonly name: string, public readonly dtype: DataType, | ||
private maxSize: number, private elementShape: number[], | ||
public readonly identicalElementShapes: boolean, | ||
public readonly dynamicSize: boolean, | ||
public readonly clearAfterRead: boolean) { | ||
this.id = TensorArray.nextId++; | ||
readonly name: string, readonly dtype: DataType, private maxSize: number, | ||
private elementShape: number[], readonly identicalElementShapes: boolean, | ||
readonly dynamicSize: boolean, readonly clearAfterRead: boolean) { | ||
this.idTensor = scalar(0); | ||
keep(this.idTensor); | ||
} | ||
get id() { | ||
return this.idTensor.id; | ||
} | ||
get closed() { | ||
@@ -52,3 +55,3 @@ return this.closed_; | ||
/** | ||
* Close the current TensorArray. | ||
* Dispose the tensors and idTensor and mark the TensoryArray as closed. | ||
*/ | ||
@@ -59,2 +62,3 @@ clearAndClose() { | ||
this.closed_ = true; | ||
this.idTensor.dispose(); | ||
} | ||
@@ -138,3 +142,3 @@ | ||
if (t && t.read) { | ||
if (t.read) { | ||
throw new Error( | ||
@@ -145,3 +149,3 @@ `TensorArray ${this.name}: Could not write to TensorArray index ${ | ||
if (t && t.written) { | ||
if (t.written) { | ||
throw new Error( | ||
@@ -153,2 +157,3 @@ `TensorArray ${this.name}: Could not write to TensorArray index ${ | ||
t.tensor = tensor; | ||
keep(tensor); | ||
t.written = true; | ||
@@ -155,0 +160,0 @@ |
@@ -18,3 +18,3 @@ /** | ||
import {concat, DataType, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; | ||
import {concat, DataType, keep, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; | ||
@@ -39,2 +39,8 @@ import {assertShapesMatchAllowUndefinedSize} from './tensor_utils'; | ||
export class TensorList { | ||
readonly idTensor: Tensor; | ||
maxNumElements: number; | ||
get id() { | ||
return this.idTensor.id; | ||
} | ||
/** | ||
@@ -49,5 +55,21 @@ * | ||
constructor( | ||
public tensors: Tensor[], public elementShape: number[], | ||
public elementDtype: DataType, public maxNumElements = -1) {} | ||
readonly tensors: Tensor[], readonly elementShape: number[], | ||
readonly elementDtype: DataType, maxNumElements = -1) { | ||
if (tensors != null) { | ||
tensors.forEach(tensor => { | ||
if (elementDtype !== tensor.dtype) { | ||
throw new Error(`Invalid data types; op elements ${ | ||
elementDtype}, but list elements ${tensor.dtype}`); | ||
} | ||
assertShapesMatchAllowUndefinedSize( | ||
elementShape, tensor.shape, 'TensorList shape mismatch: '); | ||
keep(tensor); | ||
}); | ||
} | ||
this.idTensor = scalar(0); | ||
this.maxNumElements = maxNumElements; | ||
keep(this.idTensor); | ||
} | ||
/** | ||
@@ -62,2 +84,10 @@ * Get a new TensorList containing a copy of the underlying tensor container. | ||
/** | ||
* Dispose the tensors and idTensor and clear the tensor list. | ||
*/ | ||
clearAndClose() { | ||
this.tensors.forEach(tensor => tensor.dispose()); | ||
this.tensors.length = 0; | ||
this.idTensor.dispose(); | ||
} | ||
/** | ||
* The size of the tensors in the tensor list. | ||
@@ -133,2 +163,3 @@ */ | ||
} | ||
keep(tensor); | ||
this.tensors.push(tensor); | ||
@@ -201,3 +232,3 @@ } | ||
this.elementShape, tensor.shape, 'TensorList shape mismatch: '); | ||
keep(tensor); | ||
this.tensors[elementIndex] = tensor; | ||
@@ -267,3 +298,4 @@ } | ||
*/ | ||
export function fromTensor(tensor: Tensor, elementShape: number[]) { | ||
export function fromTensor( | ||
tensor: Tensor, elementShape: number[], elementDtype: DataType) { | ||
const dtype = tensor.dtype; | ||
@@ -274,3 +306,6 @@ if (tensor.shape.length < 1) { | ||
} | ||
if (tensor.dtype !== elementDtype) { | ||
throw new Error(`Invalid data types; op elements ${ | ||
tensor.dtype}, but list elements ${elementDtype}`); | ||
} | ||
const outputShape = tensor.shape.slice(1); | ||
@@ -280,7 +315,3 @@ assertShapesMatchAllowUndefinedSize( | ||
const tensorList: Tensor[] = []; | ||
for (let i = 0; i < tensor.shape[0]; ++i) { | ||
const tmp = tensor.slice(i, i + 1).reshape(outputShape); | ||
tensorList.push(tmp); | ||
} | ||
const tensorList: Tensor[] = tensor.unstack(); | ||
return new TensorList(tensorList, elementShape, dtype); | ||
@@ -309,3 +340,3 @@ } | ||
tensor: Tensor, indices: number[], elementShape: number[], | ||
numElements: number): TensorList { | ||
numElements?: number): TensorList { | ||
if (indices.length !== tensor.shape[0]) { | ||
@@ -318,3 +349,3 @@ throw new Error(`Expected len(indices) == tensor.shape[0], but saw: ${ | ||
if (numElements !== -1 && maxIndex >= numElements) { | ||
if (numElements != null && numElements !== -1 && maxIndex >= numElements) { | ||
throw new Error( | ||
@@ -363,2 +394,3 @@ `Max index must be < array size (${maxIndex} vs. ${numElements})`); | ||
} | ||
tensor.dispose(); | ||
return tensors; | ||
@@ -365,0 +397,0 @@ }); |
@@ -19,6 +19,13 @@ /** | ||
import {Tensor} from '@tensorflow/tfjs-core'; | ||
import {NamedTensorsMap} from '../data/types'; | ||
import {NamedTensorsMap, TensorArrayMap, TensorListMap} from '../data/types'; | ||
/** | ||
* | ||
*/ | ||
export interface FunctionExecutor { | ||
executeFunctionAsync(inputs: Tensor[]): Promise<Tensor[]>; | ||
executeFunctionAsync( | ||
inputs: Tensor[], tensorArrayMap: TensorArrayMap, | ||
tensorListMap: TensorListMap): Promise<Tensor[]>; | ||
weightMap: NamedTensorsMap; | ||
} |
@@ -24,2 +24,3 @@ /** | ||
import {TensorArray} from '../../executor/tensor_array'; | ||
import {fromTensor, reserve, scatter, split} from '../../executor/tensor_list'; | ||
import {InternalOpAsyncExecutor, Node} from '../types'; | ||
@@ -45,5 +46,7 @@ | ||
if (condValue[0]) { | ||
return context.functionMap[thenFunc].executeFunctionAsync(args); | ||
return context.functionMap[thenFunc].executeFunctionAsync( | ||
args, context.tensorArrayMap, context.tensorListMap); | ||
} else { | ||
return context.functionMap[elseFunc].executeFunctionAsync(args); | ||
return context.functionMap[elseFunc].executeFunctionAsync( | ||
args, context.tensorArrayMap, context.tensorListMap); | ||
} | ||
@@ -59,20 +62,55 @@ } | ||
getParamValue('args', node, tensorMap, context) as tfc.Tensor[]; | ||
const condTensor = | ||
(await context.functionMap[condFunc].executeFunctionAsync(args))[0]; | ||
let condValue = await condTensor.data(); | ||
// Calculate the condition of the loop | ||
const condResult = | ||
(await context.functionMap[condFunc].executeFunctionAsync( | ||
args, context.tensorArrayMap, context.tensorListMap)); | ||
const argIds = args.map(tensor => tensor.id); | ||
let condValue = await condResult[0].data(); | ||
// Dispose the intermediate tensors for condition function | ||
condResult.forEach(tensor => { | ||
if (!tensor.kept && argIds.indexOf(tensor.id) === -1) { | ||
tensor.dispose(); | ||
} | ||
}); | ||
let result: tfc.Tensor[] = args; | ||
while (condValue[0]) { | ||
result = | ||
await context.functionMap[bodyFunc].executeFunctionAsync(result); | ||
const condTensor = | ||
// Record the previous result for intermediate tensor tracking | ||
const origResult = result; | ||
// Execution the body of the loop | ||
result = await context.functionMap[bodyFunc].executeFunctionAsync( | ||
result, context.tensorArrayMap, context.tensorListMap); | ||
const resultIds = result.map(tensor => tensor.id); | ||
// Dispose the intermediate tensor for body function that is not global | ||
// kept, not input/output of the body function | ||
origResult.forEach(tensor => { | ||
if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && | ||
resultIds.indexOf(tensor.id) === -1) { | ||
tensor.dispose(); | ||
} | ||
}); | ||
// Recalcuate the condition of the loop using the latest results. | ||
const condResult = | ||
(await context.functionMap[condFunc].executeFunctionAsync( | ||
result))[0]; | ||
condValue = await condTensor.data(); | ||
result, context.tensorArrayMap, context.tensorListMap)); | ||
condValue = await condResult[0].data(); | ||
// Dispose the intermediate tensors for condition function | ||
condResult.forEach(tensor => { | ||
if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && | ||
resultIds.indexOf(tensor.id) === -1) { | ||
tensor.dispose(); | ||
} | ||
}); | ||
} | ||
return result; | ||
} | ||
case 'LoopCond': | ||
case 'LoopCond': { | ||
return [ | ||
(getParamValue('pred', node, tensorMap, context) as tfc.Tensor).clone() | ||
]; | ||
} | ||
case 'Switch': { | ||
@@ -87,3 +125,3 @@ const pred = | ||
} | ||
case 'Merge': | ||
case 'Merge': { | ||
const inputName = node.inputNames.find( | ||
@@ -93,4 +131,4 @@ name => getTensor(name, tensorMap, context) !== undefined); | ||
undefined; | ||
case 'Enter': | ||
} | ||
case 'Enter': { | ||
const frameId = | ||
@@ -102,4 +140,4 @@ getParamValue('frameName', node, tensorMap, context) as string; | ||
return [data.clone()]; | ||
case 'Exit': | ||
} | ||
case 'Exit': { | ||
const tensor = | ||
@@ -109,4 +147,4 @@ getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
return [tensor.clone()]; | ||
case 'NextIteration': | ||
} | ||
case 'NextIteration': { | ||
const input = | ||
@@ -116,4 +154,4 @@ getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
return [input.clone()]; | ||
case 'TensorArrayV3': | ||
} | ||
case 'TensorArrayV3': { | ||
const size = getParamValue('size', node, tensorMap, context) as number; | ||
@@ -136,25 +174,26 @@ const dtype = | ||
context.addTensorArray(tensorArray); | ||
return [scalar(tensorArray.id), scalar(1.0)]; | ||
case 'TensorArrayWriteV3': | ||
const id = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
return [tensorArray.idTensor, scalar(1.0)]; | ||
} | ||
case 'TensorArrayWriteV3': { | ||
const id = getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const index = getParamValue('index', node, tensorMap, context) as number; | ||
const writeTensor = | ||
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
const writeTensorArray = context.getTensorArray(id); | ||
const writeTensorArray = context.getTensorArray(id.id); | ||
writeTensorArray.write(index, writeTensor); | ||
return [scalar(1.0)]; | ||
case 'TensorArrayReadV3': | ||
const readId = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
return [writeTensorArray.idTensor]; | ||
} | ||
case 'TensorArrayReadV3': { | ||
const readId = getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const readIndex = | ||
getParamValue('index', node, tensorMap, context) as number; | ||
const readTensorArray = context.getTensorArray(readId); | ||
const readTensorArray = context.getTensorArray(readId.id); | ||
return [readTensorArray.read(readIndex)]; | ||
case 'TensorArrayGatherV3': | ||
} | ||
case 'TensorArrayGatherV3': { | ||
const gatherId = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const gatherIndices = | ||
@@ -164,8 +203,9 @@ getParamValue('indices', node, tensorMap, context) as number[]; | ||
getParamValue('dtype', node, tensorMap, context) as tfc.DataType; | ||
const gatherTensorArray = context.getTensorArray(gatherId); | ||
const gatherTensorArray = context.getTensorArray(gatherId.id); | ||
return [gatherTensorArray.gather(gatherIndices, gatherDtype)]; | ||
case 'TensorArrayScatterV3': | ||
} | ||
case 'TensorArrayScatterV3': { | ||
const scatterId = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const scatterIndices = | ||
@@ -175,17 +215,19 @@ getParamValue('indices', node, tensorMap, context) as number[]; | ||
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
const scatterTensorArray = context.getTensorArray(scatterId); | ||
const scatterTensorArray = context.getTensorArray(scatterId.id); | ||
scatterTensorArray.scatter(scatterIndices, scatterTensor); | ||
return [scalar(1.0)]; | ||
case 'TensorArrayConcatV3': | ||
return [scatterTensorArray.idTensor]; | ||
} | ||
case 'TensorArrayConcatV3': { | ||
const concatId = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
const concatTensorArray = context.getTensorArray(concatId); | ||
getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const concatTensorArray = context.getTensorArray(concatId.id); | ||
const concatDtype = | ||
getParamValue('dtype', node, tensorMap, context) as tfc.DataType; | ||
return [concatTensorArray.concat(concatDtype)]; | ||
case 'TensorArraySplitV3': | ||
} | ||
case 'TensorArraySplitV3': { | ||
const splitId = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const splitTensor = | ||
@@ -195,18 +237,151 @@ getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
getParamValue('lengths', node, tensorMap, context) as number[]; | ||
const splitTensorArray = context.getTensorArray(splitId); | ||
const splitTensorArray = context.getTensorArray(splitId.id); | ||
splitTensorArray.split(lengths, splitTensor); | ||
return [scalar(1.0)]; | ||
case 'TensorArraySizeV3': | ||
const sizeId = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
const sizeTensorArray = context.getTensorArray(sizeId); | ||
return [splitTensorArray.idTensor]; | ||
} | ||
case 'TensorArraySizeV3': { | ||
const sizeId = getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const sizeTensorArray = context.getTensorArray(sizeId.id); | ||
return [scalar(sizeTensorArray.size(), 'int32')]; | ||
case 'TensorArrayCloseV3': | ||
} | ||
case 'TensorArrayCloseV3': { | ||
const closeId = | ||
getParamValue('tensorArrayId', node, tensorMap, context) as number; | ||
const closeTensorArray = context.getTensorArray(closeId); | ||
getParamValue('tensorArrayId', node, tensorMap, context) as | ||
tfc.Tensor; | ||
const closeTensorArray = context.getTensorArray(closeId.id); | ||
closeTensorArray.clearAndClose(); | ||
return [scalar(0)]; | ||
return [closeTensorArray.idTensor]; | ||
} | ||
case 'TensorListSetItem': { | ||
const idTensor = | ||
getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; | ||
const index = getParamValue('index', node, tensorMap, context) as number; | ||
const writeTensor = | ||
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
const tensorList = context.getTensorList(idTensor.id); | ||
tensorList.setItem(index, writeTensor); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListGetItem': { | ||
const idTensor = | ||
getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; | ||
const readIndex = | ||
getParamValue('index', node, tensorMap, context) as number; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const elementDType = | ||
getParamValue('elementDType', node, tensorMap, context) as | ||
tfc.DataType; | ||
const tensorList = context.getTensorList(idTensor.id); | ||
return [tensorList.getItem(readIndex, elementShape, elementDType)]; | ||
} | ||
case 'TensorListScatterV2': | ||
case 'TensorListScatter': { | ||
const scatterIndices = | ||
getParamValue('indices', node, tensorMap, context) as number[]; | ||
const scatterTensor = | ||
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const numElements = | ||
getParamValue('numElements', node, tensorMap, context) as number; | ||
const tensorList = | ||
scatter(scatterTensor, scatterIndices, elementShape, numElements); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListReserve': { | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const elementDtype = | ||
getParamValue('elementDType', node, tensorMap, context) as | ||
tfc.DataType; | ||
const numElements = | ||
getParamValue('numElements', node, tensorMap, context) as number; | ||
const tensorList = reserve(elementShape, elementDtype, numElements); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListGather': { | ||
const gatherId = | ||
getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; | ||
const gatherIndices = | ||
getParamValue('indices', node, tensorMap, context) as number[]; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const elementDtype = | ||
getParamValue('elementDType', node, tensorMap, context) as | ||
tfc.DataType; | ||
const tensorList = context.getTensorList(gatherId.id); | ||
return [tensorList.gather(gatherIndices, elementDtype, elementShape)]; | ||
} | ||
case 'TensorListStack': { | ||
const idTensor = | ||
getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const elementDtype = | ||
getParamValue('elementDType', node, tensorMap, context) as | ||
tfc.DataType; | ||
const numElements = | ||
getParamValue('numElements', node, tensorMap, context) as number; | ||
const tensorList = context.getTensorList(idTensor.id); | ||
return [tensorList.stack(elementShape, elementDtype, numElements)]; | ||
} | ||
case 'TensorListFromTensor': { | ||
const tensor = | ||
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const elementDtype = | ||
getParamValue('elementDType', node, tensorMap, context) as | ||
tfc.DataType; | ||
const tensorList = fromTensor(tensor, elementShape, elementDtype); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListConcat': { | ||
const concatId = | ||
getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; | ||
const tensorList = context.getTensorList(concatId.id); | ||
const concatDtype = | ||
getParamValue('dtype', node, tensorMap, context) as tfc.DataType; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
return [tensorList.concat(concatDtype, elementShape)]; | ||
} | ||
case 'TensorListPushBack': { | ||
const idTensor = | ||
getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; | ||
const writeTensor = | ||
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
const tensorList = context.getTensorList(idTensor.id); | ||
tensorList.pushBack(writeTensor); | ||
return [tensorList.idTensor]; | ||
} | ||
case 'TensorListPopBack': { | ||
const idTensor = | ||
getParamValue('tensorListId', node, tensorMap, context) as tfc.Tensor; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const elementDType = | ||
getParamValue('elementDType', node, tensorMap, context) as | ||
tfc.DataType; | ||
const tensorList = context.getTensorList(idTensor.id); | ||
return [tensorList.popBack(elementShape, elementDType)]; | ||
} | ||
case 'TensorListSplit': { | ||
const splitTensor = | ||
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; | ||
const elementShape = | ||
getParamValue('elementShape', node, tensorMap, context) as number[]; | ||
const lengths = | ||
getParamValue('lengths', node, tensorMap, context) as number[]; | ||
const tensorList = split(splitTensor, lengths, elementShape); | ||
context.addTensorList(tensorList); | ||
return [tensorList.idTensor]; | ||
} | ||
default: | ||
@@ -213,0 +388,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
@@ -24,3 +24,3 @@ /** | ||
import {getParamValue} from './utils'; | ||
import {getPadding, getParamValue} from './utils'; | ||
@@ -50,3 +50,3 @@ export const executeOp: InternalOpExecutor = (node: Node, | ||
getParamValue('strides', node, tensorMap, context) as number[]; | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
const dataFormat = | ||
@@ -93,3 +93,3 @@ (getParamValue('dataFormat', node, tensorMap, context) as string) | ||
getParamValue('strides', node, tensorMap, context) as number[]; | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
const dataFormat = | ||
@@ -127,3 +127,3 @@ (getParamValue('dataFormat', node, tensorMap, context) as string) | ||
getParamValue('strides', node, tensorMap, context) as number[]; | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
return [tfc.conv2dTranspose( | ||
@@ -139,3 +139,3 @@ getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | | ||
getParamValue('strides', node, tensorMap, context) as number[]; | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const pad = getPadding(node, tensorMap, context); | ||
const dilations = | ||
@@ -239,2 +239,25 @@ getParamValue('dilations', node, tensorMap, context) as number[]; | ||
case 'Dilation2D': { | ||
const strides = | ||
getParamValue('strides', node, tensorMap, context) as number[]; | ||
const pad = getParamValue('pad', node, tensorMap, context); | ||
const dilations = | ||
getParamValue('dilations', node, tensorMap, context) as number[]; | ||
// strides: [1, stride_height, stride_width, 1]. | ||
const strideHeight = strides[1]; | ||
const strideWidth = strides[2]; | ||
// dilations: [1, dilation_height, dilation_width, 1]. | ||
const dilationHeight = dilations[1]; | ||
const dilationWidth = dilations[2]; | ||
return [tfc.dilation2d( | ||
getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | | ||
tfc.Tensor4D, | ||
getParamValue('filter', node, tensorMap, context) as tfc.Tensor3D, | ||
[strideHeight, strideWidth], pad as 'valid' | 'same', | ||
[dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)]; | ||
} | ||
default: | ||
@@ -241,0 +264,0 @@ throw TypeError(`Node type ${node.op} is not implemented`); |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -31,2 +31,3 @@ * you may not use this file except in compliance with the License. | ||
case 'NonMaxSuppressionV5': | ||
case 'NonMaxSuppressionV4': | ||
case 'NonMaxSuppressionV3': | ||
@@ -56,2 +57,14 @@ case 'NonMaxSuppressionV2': { | ||
if (node.op === 'NonMaxSuppressionV4') { | ||
const padToMaxOutputSize = | ||
getParamValue('padToMaxOutputSize', node, tensorMap, context) as | ||
boolean; | ||
const result = await tfc.image.nonMaxSuppressionPaddedAsync( | ||
boxes as tfc.Tensor2D, scores as tfc.Tensor1D, maxOutputSize, | ||
iouThreshold, scoreThreshold, padToMaxOutputSize); | ||
return [result.selectedIndices, result.validOutputs]; | ||
} | ||
return [await tfc.image.nonMaxSuppressionAsync( | ||
@@ -58,0 +71,0 @@ boxes as tfc.Tensor2D, scores as tfc.Tensor1D, maxOutputSize, |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -5,0 +5,0 @@ * you may not use this file except in compliance with the License. |
/** | ||
* @license | ||
* Copyright 2018 Google Inc. All Rights Reserved. | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -27,5 +27,5 @@ * you may not use this file except in compliance with the License. | ||
export const executeOp: InternalOpExecutor = (node: Node, | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
switch (node.op) { | ||
@@ -32,0 +32,0 @@ case 'ResizeBilinear': { |
@@ -27,5 +27,5 @@ /** | ||
export const executeOp: InternalOpExecutor = (node: Node, | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): | ||
tfc.Tensor[] => { | ||
switch (node.op) { | ||
@@ -81,9 +81,3 @@ case 'ConcatV2': | ||
const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; | ||
if (begin.length === 1 && tensor.shape.length > 1) { | ||
for (let i = 1; i < tensor.shape.length; i++) { | ||
begin.push(0); | ||
end.push(tensor.shape[i]); | ||
strides.push(strides[0]); | ||
} | ||
} | ||
return [tfc.stridedSlice( | ||
@@ -131,5 +125,5 @@ tensor, begin, end, strides, beginMask, endMask, ellipsisMask, | ||
number[]; | ||
return tfc.split( | ||
getParamValue('x', node, tensorMap, context) as tfc.Tensor, | ||
numOrSizeSplits, axis); | ||
const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; | ||
return tfc.split(tensor, numOrSizeSplits, axis); | ||
} | ||
@@ -136,0 +130,0 @@ case 'ScatterNd': { |
@@ -43,2 +43,7 @@ /** | ||
} | ||
export function createShapeAttrFromIndex(inputIndex: number): InputParamValue { | ||
return {inputIndexStart: inputIndex, type: 'shape'}; | ||
} | ||
export function createNumericArrayAttr(value: number[]): ParamValue { | ||
@@ -71,3 +76,3 @@ return {value, type: 'number[]'}; | ||
opMappers.find(mapper => mapper.tfOpName === node.op); | ||
return Object.keys(node.inputParams).every(key => { | ||
const matched = Object.keys(node.inputParams).every(key => { | ||
const value = node.inputParams[key]; | ||
@@ -83,2 +88,7 @@ const def = opMapper.inputs.find(param => param.name === key); | ||
}); | ||
if (!matched) { | ||
console.log('node = ', node); | ||
console.log('opMapper = ', opMapper); | ||
} | ||
return matched; | ||
} |
@@ -24,3 +24,3 @@ /** | ||
import {getParamValue, split} from './utils'; | ||
import {getParamValue} from './utils'; | ||
@@ -58,5 +58,4 @@ export const executeOp: InternalOpExecutor = (node: Node, | ||
getParamValue('x', node, tensorMap, context) as tfc.Tensor, | ||
split( | ||
getParamValue('padding', node, tensorMap, context) as number[], | ||
2) as Array<[number, number]>, | ||
getParamValue('padding', node, tensorMap, context) as | ||
Array<[number, number]>, | ||
getParamValue('constantValue', node, tensorMap, context) as number)]; | ||
@@ -67,4 +66,4 @@ } | ||
getParamValue('blockShape', node, tensorMap, context) as number[]; | ||
const paddings = split( | ||
getParamValue('paddings', node, tensorMap, context) as number[], 2); | ||
const paddings = | ||
getParamValue('paddings', node, tensorMap, context) as number[][]; | ||
return [tfc.spaceToBatchND( | ||
@@ -77,4 +76,4 @@ getParamValue('x', node, tensorMap, context) as tfc.Tensor, | ||
getParamValue('blockShape', node, tensorMap, context) as number[]; | ||
const crops = split( | ||
getParamValue('crops', node, tensorMap, context) as number[], 2); | ||
const crops = | ||
getParamValue('crops', node, tensorMap, context) as number[][]; | ||
return [tfc.batchToSpaceND( | ||
@@ -81,0 +80,0 @@ getParamValue('x', node, tensorMap, context) as tfc.Tensor, |
@@ -43,6 +43,8 @@ /** | ||
} | ||
const data = Array.prototype.slice.call( | ||
getTensor(node.inputNames.slice(start)[0], tensorMap, context) | ||
.dataSync()); | ||
return inputParam.type === 'number' ? data[0] : data; | ||
const tensor = | ||
getTensor(node.inputNames.slice(start)[0], tensorMap, context); | ||
const data = tensor.dataSync(); | ||
return inputParam.type === 'number' ? | ||
data[0] : | ||
tfc.util.toNestedArray(tensor.shape, data); | ||
} | ||
@@ -120,1 +122,19 @@ const attrParam = node.attrParams[paramName]; | ||
} | ||
export function getPadding( | ||
node: Node, tensorMap: NamedTensorsMap, | ||
context: ExecutionContext): ValueType { | ||
let pad = getParamValue('pad', node, tensorMap, context); | ||
if (pad === 'explicit') { | ||
// This is 1d array, we need to convert it to 2d array | ||
pad = getParamValue('explicitPaddings', node, tensorMap, context); | ||
const explicitPadding: [ | ||
[number, number], [number, number], [number, number], [number, number] | ||
] = [[0, 0], [0, 0], [0, 0], [0, 0]]; | ||
for (let i = 0; i < 4; i++) { | ||
explicitPadding[i][0] = (pad as number[])[i * 2]; | ||
explicitPadding[i][1] = (pad as number[])[i * 2 + 1]; | ||
} | ||
return explicitPadding; | ||
} | ||
return pad; | ||
} |
@@ -94,3 +94,3 @@ import {OpMapper} from '../types'; | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'index', 'type': 'number'}, | ||
@@ -108,3 +108,3 @@ {'start': 2, 'name': 'tensor', 'type': 'tensor'}, | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'index', 'type': 'number'}, | ||
@@ -124,3 +124,3 @@ {'start': 2, 'name': 'flowIn', 'type': 'number'}, | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'indices', 'type': 'number[]'}, | ||
@@ -138,3 +138,3 @@ {'start': 2, 'name': 'flowIn', 'type': 'number'}, | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'indices', 'type': 'number[]'}, | ||
@@ -150,3 +150,3 @@ {'start': 2, 'name': 'tensor', 'type': 'tensor'}, | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'flowIn', 'type': 'number'}, | ||
@@ -167,3 +167,3 @@ ], | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'tensor', 'type': 'tensor'}, | ||
@@ -179,3 +179,3 @@ {'start': 2, 'name': 'lengths', 'type': 'number[]'}, | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}, | ||
{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'flowIn', 'type': 'number'} | ||
@@ -187,3 +187,3 @@ ] | ||
'category': 'control', | ||
'inputs': [{'start': 0, 'name': 'tensorArrayId', 'type': 'number'}] | ||
'inputs': [{'start': 0, 'name': 'tensorArrayId', 'type': 'tensor'}] | ||
}, | ||
@@ -235,3 +235,133 @@ { | ||
] | ||
} | ||
}, | ||
{ | ||
'tfOpName': 'TensorListScatter', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensor', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'indices', 'type': 'number[]'}, | ||
{'start': 2, 'name': 'elementShape', 'type': 'shape'} | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListScatterV2', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensor', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'indices', 'type': 'number[]'}, | ||
{'start': 2, 'name': 'elementShape', 'type': 'shape'}, | ||
{'start': 3, 'name': 'numElements', 'type': 'number'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListGather', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'indices', 'type': 'number[]'}, | ||
{'start': 2, 'name': 'elementShape', 'type': 'shape'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListGetItem', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'index', 'type': 'number'}, | ||
{'start': 2, 'name': 'elementShape', 'type': 'shape'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListSetItem', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'index', 'type': 'number'}, | ||
{'start': 2, 'name': 'tensor', 'type': 'tensor'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListReserve', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'elementShape', 'type': 'shape'}, | ||
{'start': 1, 'name': 'numElements', 'type': 'number'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListFromTensor', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensor', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'elementShape', 'type': 'shape'} | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListStack', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'elementShape', 'type': 'shape'}, | ||
], | ||
'attrs': [ | ||
{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}, | ||
{'tfName': 'num_elements', 'name': 'numElements', 'type': 'dtype'} | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListSplit', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensor', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'elementShape', 'type': 'shape'}, | ||
{'start': 2, 'name': 'lengths', 'type': 'number[]'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListConcat', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, | ||
], | ||
'attrs': [ | ||
{'tfName': 'element_shape', 'name': 'elementShape', 'type': 'shape'}, | ||
{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'} | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListPopBack', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'elementShape', 'type': 'shape'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
{ | ||
'tfOpName': 'TensorListPushBack', | ||
'category': 'control', | ||
'inputs': [ | ||
{'start': 0, 'name': 'tensorListId', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'tensor', 'type': 'tensor'}, | ||
], | ||
'attrs': | ||
[{'tfName': 'element_dtype', 'name': 'elementDType', 'type': 'dtype'}] | ||
}, | ||
]; |
@@ -1,6 +0,4 @@ | ||
import {OpMapper} from '../types'; | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Copyright 2020 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
@@ -20,2 +18,4 @@ * you may not use this file except in compliance with the License. | ||
import {OpMapper} from '../types'; | ||
export const json: OpMapper[] = [ | ||
@@ -333,3 +333,16 @@ { | ||
], | ||
}, | ||
{ | ||
'tfOpName': 'Dilation2D', | ||
'category': 'convolution', | ||
'inputs': [ | ||
{'start': 0, 'name': 'x', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'filter', 'type': 'tensor'}, | ||
], | ||
'attrs': [ | ||
{'tfName': 'strides', 'name': 'strides', 'type': 'number[]'}, | ||
{'tfName': 'rates', 'name': 'dilations', 'type': 'number[]'}, | ||
{'tfName': 'padding', 'name': 'pad', 'type': 'string'} | ||
] | ||
} | ||
]; |
@@ -43,2 +43,26 @@ import {OpMapper} from '../types'; | ||
{ | ||
'tfOpName': 'NonMaxSuppressionV4', | ||
'category': 'dynamic', | ||
'inputs': [ | ||
{'start': 0, 'name': 'boxes', 'type': 'tensor'}, | ||
{'start': 1, 'name': 'scores', 'type': 'tensor'}, | ||
{'start': 2, 'name': 'maxOutputSize', 'type': 'number'}, | ||
{'start': 3, 'name': 'iouThreshold', 'type': 'number'}, | ||
{'start': 4, 'name': 'scoreThreshold', 'type': 'number'} | ||
], | ||
'attrs': [ | ||
{'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true}, { | ||
'tfName': 'T_threshold', | ||
'name': 'threshold', | ||
'type': 'dtype', | ||
'notSupported': true | ||
}, | ||
{ | ||
'tfName': 'pad_to_max_output_size', | ||
'name': 'padToMaxOutputSize', | ||
'type': 'bool' | ||
} | ||
] | ||
}, | ||
{ | ||
'tfOpName': 'NonMaxSuppressionV5', | ||
@@ -45,0 +69,0 @@ 'category': 'dynamic', |
/** @license See the LICENSE file. */ | ||
// This code is auto-generated, do not modify this file! | ||
const version = '2.0.1'; | ||
const version = '2.1.0'; | ||
export {version}; |
@@ -21,3 +21,3 @@ /** | ||
if (argv.length < 3) { | ||
console.log('Usage: ts-node pb2json.ts model_file'); | ||
console.log('Usage: yarn model-summary model_file'); | ||
return; | ||
@@ -29,18 +29,34 @@ } | ||
const rawdata = fs.readFileSync(sourcePath); | ||
const nodes: Array<any> = JSON.parse(rawdata.toString())['modelTopology']['node']; | ||
const model = JSON.parse(rawdata.toString()); | ||
if (model.format !== 'graph-model') { | ||
console.log('This tool only supports TFJS Graph models.'); | ||
return; | ||
} | ||
// tslint:disable-next-line: no-any | ||
let nodes: any[] = model['modelTopology']['node']; | ||
const library = model['modelTopology']['library']; | ||
if (library != null) { | ||
const functions = library['function']; | ||
// tslint:disable-next-line: no-any | ||
if (functions != null) { | ||
functions.forEach((func: any) => nodes = nodes.concat(func['nodeDef'])); | ||
} | ||
} | ||
const opCount: {[key: string]: number} = {}; | ||
for (const opNode of nodes) { | ||
nodes.forEach(opNode => { | ||
let count = 0; | ||
const op = opNode['op']; | ||
if (opCount[op]) { | ||
count = opCount[op]; | ||
} | ||
opCount[op] = count + 1; | ||
} | ||
count = opCount[op]; | ||
} | ||
opCount[op] = count + 1; | ||
}); | ||
console.log(opCount); | ||
console.log('Total ops = ' + nodes.length); | ||
const keys = Object.keys(opCount).sort(); | ||
keys.forEach(key => console.log(`${key}: ${opCount[key]}`)); | ||
console.log(`Total ops = ${nodes.length}`); | ||
} | ||
summarize(process.argv); |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
5839079
42222