@tensorflow/tfjs-converter
Advanced tools
Comparing version 0.1.1 to 0.1.2
@@ -29,2 +29,3 @@ "use strict"; | ||
var graph = require("../operations/op_list/graph.json"); | ||
var image = require("../operations/op_list/image.json"); | ||
var logical = require("../operations/op_list/logical.json"); | ||
@@ -37,3 +38,3 @@ var matrices = require("../operations/op_list/matrices.json"); | ||
var DOC_DIR = './docs/'; | ||
var opMappers = __spread(arithmetic, basicMath, convolution, creation, logical, graph, matrices, normalization, reduction, sliceJoin, transformation); | ||
var opMappers = __spread(arithmetic, basicMath, convolution, creation, logical, image, graph, matrices, normalization, reduction, sliceJoin, transformation); | ||
var output = []; | ||
@@ -49,2 +50,3 @@ output.push('# Supported Tensorflow Ops\n\n'); | ||
generateTable('Normalization', normalization, output); | ||
generateTable('Image', image, output); | ||
generateTable('Reduction', reduction, output); | ||
@@ -51,0 +53,0 @@ generateTable('Slice and Join', sliceJoin, output); |
@@ -7,3 +7,7 @@ import { NamedTensorMap, NamedTensorsMap } from '../data/index'; | ||
private _weightMap; | ||
private placeholders; | ||
private outputs; | ||
weightMap: NamedTensorsMap; | ||
readonly inputNodes: string[]; | ||
readonly outputNodes: string[]; | ||
constructor(graph: operations.Graph); | ||
@@ -13,2 +17,3 @@ private compile(); | ||
dispose(): void; | ||
private checkInput(inputs); | ||
} |
@@ -39,2 +39,4 @@ "use strict"; | ||
this._weightMap = {}; | ||
this.placeholders = graph.placeholders.map(function (node) { return node.name; }); | ||
this.outputs = graph.outputs.map(function (node) { return node.name; }); | ||
this.compile(); | ||
@@ -52,2 +54,16 @@ } | ||
}); | ||
Object.defineProperty(GraphExecutor.prototype, "inputNodes", { | ||
get: function () { | ||
return this.placeholders; | ||
}, | ||
enumerable: true, | ||
configurable: true | ||
}); | ||
Object.defineProperty(GraphExecutor.prototype, "outputNodes", { | ||
get: function () { | ||
return this.outputs; | ||
}, | ||
enumerable: true, | ||
configurable: true | ||
}); | ||
GraphExecutor.prototype.compile = function () { | ||
@@ -69,2 +85,3 @@ var stack = __spread(this.graph.inputs); | ||
var _this = this; | ||
this.checkInput(inputs); | ||
var result = tfjs_core_1.tidy(function () { | ||
@@ -91,2 +108,22 @@ var tensors = _this.compiledOrder.reduce(function (map, node) { | ||
}; | ||
GraphExecutor.prototype.checkInput = function (inputs) { | ||
var _this = this; | ||
var inputKeys = Object.keys(inputs); | ||
var missing = []; | ||
var extra = []; | ||
this.placeholders.forEach(function (name) { | ||
if (inputKeys.indexOf(name) === -1) | ||
missing.push(name); | ||
}); | ||
inputKeys.forEach(function (name) { | ||
if (_this.placeholders.indexOf(name) === -1) | ||
extra.push(name); | ||
}); | ||
if (missing.length > 0) { | ||
throw new Error("Missing input placeholders: " + missing); | ||
} | ||
if (extra.length > 0) { | ||
throw new Error("Extra input tensors: " + extra); | ||
} | ||
}; | ||
return GraphExecutor; | ||
@@ -93,0 +130,0 @@ }()); |
@@ -22,3 +22,3 @@ "use strict"; | ||
case 'pad': { | ||
return [tfc.pad(utils_1.getParamValue('x', node, tensorMap), utils_1.getParamValue('padding', node, tensorMap), utils_1.getParamValue('constantValue', node, tensorMap))]; | ||
return [tfc.pad(utils_1.getParamValue('x', node, tensorMap), utils_1.split(utils_1.getParamValue('padding', node, tensorMap), 2), utils_1.getParamValue('constantValue', node, tensorMap))]; | ||
} | ||
@@ -25,0 +25,0 @@ default: |
@@ -6,1 +6,2 @@ import * as tfc from '@tensorflow/tfjs-core'; | ||
export declare function getTensor(name: string, tensorMap: NamedTensorsMap): tfc.Tensor; | ||
export declare function split(arr: number[], size: number): number[][]; |
@@ -35,2 +35,10 @@ "use strict"; | ||
exports.getTensor = getTensor; | ||
function split(arr, size) { | ||
var res = []; | ||
for (var i = 0; i < arr.length; i += size) { | ||
res.push(arr.slice(i, i + size)); | ||
} | ||
return res; | ||
} | ||
exports.split = split; | ||
//# sourceMappingURL=utils.js.map |
@@ -8,2 +8,3 @@ "use strict"; | ||
var graph = require("./executors/graph_executor"); | ||
var image = require("./executors/image_executor"); | ||
var logical = require("./executors/logical_executor"); | ||
@@ -25,2 +26,4 @@ var matrices = require("./executors/matrices_executor"); | ||
return creation.executeOp(node, tensorMap); | ||
case 'image': | ||
return image.executeOp(node, tensorMap); | ||
case 'graph': | ||
@@ -27,0 +30,0 @@ return graph.executeOp(node, tensorMap); |
@@ -58,2 +58,3 @@ "use strict"; | ||
var withControlFlow = false; | ||
var placeholders = []; | ||
var nodes = tfNodes.reduce(function (map, node) { | ||
@@ -63,2 +64,4 @@ map[node.name] = _this.mapNode(node); | ||
withControlFlow = true; | ||
if (node.op === 'Placeholder') | ||
placeholders.push(map[node.name]); | ||
return map; | ||
@@ -82,3 +85,3 @@ }, {}); | ||
}); | ||
return { nodes: nodes, inputs: inputs, outputs: outputs, withControlFlow: withControlFlow }; | ||
return { nodes: nodes, inputs: inputs, outputs: outputs, placeholders: placeholders, withControlFlow: withControlFlow }; | ||
}; | ||
@@ -85,0 +88,0 @@ OperationMapper.prototype.mapNode = function (node) { |
import { Tensor } from '@tensorflow/tfjs-core'; | ||
export declare type ParamTypes = 'number' | 'string' | 'number[]' | 'bool' | 'shape' | 'tensor' | 'tensors' | 'dtype'; | ||
export declare type Category = 'arithmetic' | 'basic_math' | 'convolution' | 'creation' | 'graph' | 'logical' | 'matrices' | 'normalization' | 'reduction' | 'slice_join' | 'transformation'; | ||
export declare type Category = 'arithmetic' | 'basic_math' | 'convolution' | 'creation' | 'image' | 'graph' | 'logical' | 'matrices' | 'normalization' | 'reduction' | 'slice_join' | 'transformation'; | ||
export interface ParamMapper { | ||
@@ -37,2 +37,3 @@ tfParamName?: string; | ||
}; | ||
placeholders: Node[]; | ||
inputs: Node[]; | ||
@@ -39,0 +40,0 @@ outputs: Node[]; |
@@ -1,2 +0,2 @@ | ||
declare const version = "0.1.1"; | ||
declare const version = "0.1.2"; | ||
export { version }; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var version = '0.1.1'; | ||
var version = '0.1.2'; | ||
exports.version = version; | ||
//# sourceMappingURL=version.js.map |
{ | ||
"name": "@tensorflow/tfjs-converter", | ||
"version": "0.1.1", | ||
"version": "0.1.2", | ||
"description": "Tensorflow model converter for javascript", | ||
@@ -5,0 +5,0 @@ "main": "dist/index.js", |
@@ -7,11 +7,13 @@ [![Build Status](https://travis-ci.org/tensorflow/tfjs-converter.svg?branch=master)](https://travis-ci.org/tensorflow/tfjs-converter) | ||
TensorFlow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model#overview_of_saving_and_restoring_models) | ||
or [Session Bundle](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md) | ||
into the browser and run inference through [TensorFlow.js](https://js.tensorflow.org). | ||
(Note: TensorFlow has deprecated session bundle format, please switch to SavedModel.) | ||
A 2-step process to import your model: | ||
1. A python pip package to convert a TensorFlow SavedModel to a web friendly format. If you already have a converted model, or are using an already hosted model (e.g. MobileNet), skip this step. | ||
1. A python pip package to convert a TensorFlow SavedModel/Session Bundle to a web friendly format. If you already have a converted model, or are using an already hosted model (e.g. MobileNet), skip this step. | ||
2. [Javascript API](./src/executor/tf_model.ts), for loading and running inference. | ||
## Step 1: Converting a [SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md) to a web-friendly format | ||
## Step 1: Converting a [SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md) or [Session Bundle](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md) to a web-friendly format | ||
@@ -28,2 +30,4 @@ 1. Install the TensorFlow.js pip package: | ||
SavedModel example: | ||
```bash | ||
@@ -38,5 +42,15 @@ $ tensorflowjs_converter \ | ||
Session bundle model example: | ||
```bash | ||
$ tensorflowjs_converter \ | ||
--input_format=tf_session_bundle \ | ||
--output_node_names='MobilenetV1/Predictions/Reshape_1' \ | ||
/mobilenet/session_bundle \ | ||
/mobilenet/web_model | ||
``` | ||
|Positional Arguments | Description | | ||
|---|---| | ||
|`input_path` | Full path of the saved model directory.| | ||
|`input_path` | Full path of the saved model or session bundle directory.| | ||
|`output_dir` | Path for all output artifacts.| | ||
@@ -47,5 +61,5 @@ | ||
|---|---| | ||
|`--input_format` | The format of input model, use tf_saved_model for SavedModel. | | ||
|`--input_format` | The format of input model, use tf_saved_model for SavedModel and tf_session_bundle for session bundle. | | ||
|`--output_node_names`| The names of the output nodes, separated by commas.| | ||
|`--saved_model_tags` | Tags of the MetaGraphDef to load, in comma separated format. Defaults to `serve`.| | ||
|`--saved_model_tags` | Only applicable to SavedModel conversion, Tags of the MetaGraphDef to load, in comma separated format. Defaults to `serve`.| | ||
@@ -81,3 +95,3 @@ | ||
```typescript | ||
import * as tfc from '@tensorflow/tfjs-core'; | ||
import * as tf from '@tensorflow/tfjs'; | ||
import {loadFrozenModel} from '@tensorflow/tfjs-converter'; | ||
@@ -90,3 +104,3 @@ | ||
const cat = document.getElementById('cat'); | ||
model.execute({input: tfc.fromPixels(cat)}); | ||
model.execute({input: tf.fromPixels(cat)}); | ||
``` | ||
@@ -113,3 +127,17 @@ | ||
## Loading the weights only | ||
If you prefer to load the weights only, you can use follow code snippet. | ||
```typescript | ||
import * as tf from '@tensorflow/tfjs'; | ||
const weightManifestUrl = "https://example.org/model/weights_manifest.json"; | ||
const manifest = await fetch(weightManifestUrl); | ||
this.weightManifest = await manifest.json(); | ||
const weightMap = await tf.loadWeights( | ||
this.weightManifest, "https://example.org/model"); | ||
``` | ||
## FAQ | ||
@@ -116,0 +144,0 @@ |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
2430879
77
44051
198