@tensorflow/tfjs-layers
Comparing version 0.7.2 to 0.7.3
base_callbacks.d.ts
@@ -7,2 +7,3 @@ import { Tensor } from '@tensorflow/tfjs-core';
};
export declare type YieldEveryOptions = 'auto' | 'batch' | 'epoch' | 'never';
export declare abstract class BaseCallback {
@@ -34,7 +35,25 @@ validationData: Tensor | Tensor[];
}
export declare class ModelTrainingYielder {
    static readonly SKIP_FIRST_BATCHES: number;
    static readonly DECISION_BATCH_COUNT: number;
    static readonly THRESHOLD_MILLIS: number;
    private yieldEvery;
    private batchCount;
    private lastYieldBatchCount;
    private batchStartMillis;
    private batchDurationsMillis;
    private autoYieldEveryBatches;
    constructor(yieldEvery: YieldEveryOptions);
    private resolveOneTensorInLogs(logs);
    maybeYieldOnBatch(logs: UnresolvedLogs): Promise<void>;
    maybeYieldOnEpoch(): Promise<void>;
}
export declare class BaseLogger extends BaseCallback {
    private seen;
    private totals;
    constructor();
    onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
    private autoYielder;
    private yieldEvery;
    constructor(yieldEvery?: YieldEveryOptions);
    onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
    onEpochBegin(epoch: number): Promise<void>;
    onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
@@ -41,0 +60,0 @@ onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
base_callbacks.js
@@ -277,10 +277,131 @@ "use strict";
exports.CallbackList = CallbackList;
var ModelTrainingYielder = (function () {
    function ModelTrainingYielder(yieldEvery) {
        this.yieldEvery = yieldEvery;
        this.batchCount = 0;
        this.batchDurationsMillis = [];
        this.autoYieldEveryBatches = null;
        this.batchStartMillis = tfjs_core_1.util.now();
    }
    ModelTrainingYielder.prototype.resolveOneTensorInLogs = function (logs) {
        return __awaiter(this, void 0, void 0, function () {
            var _a, _b, _i, key, value;
            return __generator(this, function (_c) {
                switch (_c.label) {
                    case 0:
                        _a = [];
                        for (_b in logs)
                            _a.push(_b);
                        _i = 0;
                        _c.label = 1;
                    case 1:
                        if (!(_i < _a.length)) return [3, 4];
                        key = _a[_i];
                        value = logs[key];
                        if (!(typeof value !== 'number')) return [3, 3];
                        return [4, value.data()];
                    case 2:
                        _c.sent();
                        return [3, 4];
                    case 3:
                        _i++;
                        return [3, 1];
                    case 4: return [2];
                }
            });
        });
    };
    ModelTrainingYielder.prototype.maybeYieldOnBatch = function (logs) {
        return __awaiter(this, void 0, void 0, function () {
            var t, meanBatchDuration;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (!(this.yieldEvery === 'auto')) return [3, 7];
                        this.batchCount++;
                        if (!(this.autoYieldEveryBatches == null)) return [3, 3];
                        return [4, this.resolveOneTensorInLogs(logs)];
                    case 1:
                        _a.sent();
                        t = tfjs_core_1.util.now();
                        return [4, tfjs_core_1.nextFrame()];
                    case 2:
                        _a.sent();
                        if (this.batchCount > ModelTrainingYielder.SKIP_FIRST_BATCHES) {
                            this.batchDurationsMillis.push(t - this.batchStartMillis);
                            if (this.batchDurationsMillis.length >=
                                ModelTrainingYielder.DECISION_BATCH_COUNT) {
                                meanBatchDuration = this.batchDurationsMillis.reduce(function (dur, prev) { return dur + prev; }) /
                                    this.batchDurationsMillis.length;
                                this.autoYieldEveryBatches = Math.round(ModelTrainingYielder.THRESHOLD_MILLIS / meanBatchDuration);
                                if (this.autoYieldEveryBatches < 1) {
                                    this.autoYieldEveryBatches = 1;
                                }
                            }
                        }
                        this.batchStartMillis = tfjs_core_1.util.now();
                        this.lastYieldBatchCount = this.batchCount;
                        return [3, 6];
                    case 3:
                        if (!(this.batchCount - this.lastYieldBatchCount >=
                            this.autoYieldEveryBatches)) return [3, 6];
                        return [4, tfjs_core_1.nextFrame()];
                    case 4:
                        _a.sent();
                        return [4, this.resolveOneTensorInLogs(logs)];
                    case 5:
                        _a.sent();
                        this.lastYieldBatchCount = this.batchCount;
                        _a.label = 6;
                    case 6: return [3, 9];
                    case 7:
                        if (!(this.yieldEvery === 'batch')) return [3, 9];
                        return [4, tfjs_core_1.nextFrame()];
                    case 8:
                        _a.sent();
                        _a.label = 9;
                    case 9: return [2];
                }
            });
        });
    };
    ModelTrainingYielder.prototype.maybeYieldOnEpoch = function () {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (!(this.yieldEvery === 'epoch')) return [3, 2];
                        return [4, tfjs_core_1.nextFrame()];
                    case 1:
                        _a.sent();
                        _a.label = 2;
                    case 2: return [2];
                }
            });
        });
    };
    ModelTrainingYielder.SKIP_FIRST_BATCHES = 1;
    ModelTrainingYielder.DECISION_BATCH_COUNT = 2;
    ModelTrainingYielder.THRESHOLD_MILLIS = 16;
    return ModelTrainingYielder;
}());
exports.ModelTrainingYielder = ModelTrainingYielder;
var BaseLogger = (function (_super) {
    __extends(BaseLogger, _super);
    function BaseLogger() {
        return _super.call(this) || this;
    function BaseLogger(yieldEvery) {
        var _this = _super.call(this) || this;
        _this.yieldEvery = yieldEvery || 'auto';
        return _this;
    }
    BaseLogger.prototype.onEpochBegin = function (epoch, logs) {
    BaseLogger.prototype.onTrainBegin = function (logs) {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_a) {
                this.autoYielder = new ModelTrainingYielder(this.yieldEvery);
                return [2];
            });
        });
    };
    BaseLogger.prototype.onEpochBegin = function (epoch) {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_a) {
                this.seen = 0;
@@ -297,34 +418,39 @@ this.totals = {};
            return __generator(this, function (_a) {
                if (logs == null) {
                    logs = {};
                }
                batchSize = logs['size'] == null ? 0 : logs['size'];
                this.seen += batchSize;
                _loop_1 = function (key) {
                    var value = logs[key];
                    if (typeof value === 'number') {
                        if (!this_1.totals.hasOwnProperty(key)) {
                            this_1.totals[key] = 0;
                switch (_a.label) {
                    case 0: return [4, this.autoYielder.maybeYieldOnBatch(logs)];
                    case 1:
                        _a.sent();
                        if (logs == null) {
                            logs = {};
                        }
                        this_1.totals[key] = this_1.totals[key] + value * batchSize;
                    }
                    else {
                        var oldTotalsToDispose = void 0;
                        if (key in this_1.totals) {
                            oldTotalsToDispose = this_1.totals[key];
                        batchSize = logs['size'] == null ? 0 : logs['size'];
                        this.seen += batchSize;
                        _loop_1 = function (key) {
                            var value = logs[key];
                            if (typeof value === 'number') {
                                if (!this_1.totals.hasOwnProperty(key)) {
                                    this_1.totals[key] = 0;
                                }
                                this_1.totals[key] = this_1.totals[key] + value * batchSize;
                            }
                            else {
                                var oldTotalsToDispose = void 0;
                                if (key in this_1.totals) {
                                    oldTotalsToDispose = this_1.totals[key];
                                }
                                else {
                                    this_1.totals[key] = state_1.getScalar(0);
                                }
                                this_1.totals[key] = tfjs_core_1.tidy(function () { return tfjs_core_1.add(_this.totals[key], tfjs_core_1.mul(value, state_1.getScalar(batchSize))); });
                                if (oldTotalsToDispose != null) {
                                    oldTotalsToDispose.dispose();
                                }
                            }
                        };
                        this_1 = this;
                        for (key in logs) {
                            _loop_1(key);
                        }
                        else {
                            this_1.totals[key] = state_1.getScalar(0);
                        }
                        this_1.totals[key] = tfjs_core_1.tidy(function () { return tfjs_core_1.add(_this.totals[key], tfjs_core_1.mul(value, state_1.getScalar(batchSize))); });
                        if (oldTotalsToDispose != null) {
                            oldTotalsToDispose.dispose();
                        }
                    }
                };
                this_1 = this;
                for (key in logs) {
                    _loop_1(key);
                        return [2];
                }
                return [2];
            });
@@ -338,25 +464,30 @@ });
            return __generator(this, function (_b) {
                if (logs != null) {
                    _loop_2 = function (key) {
                        if (this_2.totals[key] == null) {
                            return "continue";
                switch (_b.label) {
                    case 0: return [4, this.autoYielder.maybeYieldOnEpoch()];
                    case 1:
                        _b.sent();
                        if (logs != null) {
                            _loop_2 = function (key) {
                                if (this_2.totals[key] == null) {
                                    return "continue";
                                }
                                if (typeof this_2.totals[key] === 'number') {
                                    logs[key] = this_2.totals[key] / this_2.seen;
                                }
                                else {
                                    tfjs_core_1.tidy(function () {
                                        logs[key] = tfjs_core_1.mul(tfjs_core_1.div(state_1.getScalar(1), state_1.getScalar(_this.seen)), _this.totals[key]);
                                        _this.totals[key].dispose();
                                        tfjs_core_1.keep(logs[key]);
                                    });
                                }
                            };
                            this_2 = this;
                            for (_i = 0, _a = this.params['metrics']; _i < _a.length; _i++) {
                                key = _a[_i];
                                _loop_2(key);
                            }
                        }
                        if (typeof this_2.totals[key] === 'number') {
                            logs[key] = this_2.totals[key] / this_2.seen;
                        }
                        else {
                            tfjs_core_1.tidy(function () {
                                logs[key] = tfjs_core_1.mul(tfjs_core_1.div(state_1.getScalar(1), state_1.getScalar(_this.seen)), _this.totals[key]);
                                _this.totals[key].dispose();
                                tfjs_core_1.keep(logs[key]);
                            });
                        }
                    };
                    this_2 = this;
                    for (_i = 0, _a = this.params['metrics']; _i < _a.length; _i++) {
                        key = _a[_i];
                        _loop_2(key);
                        return [2];
                }
                return [2];
            });
@@ -363,0 +494,0 @@ });
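
The new ModelTrainingYielder above drives the 'auto' yield behavior: it times the first few batches, then calls nextFrame() roughly every THRESHOLD_MILLIS (16 ms) of training work so the browser UI stays responsive. A minimal sketch of opting into it from user code (the model and tensors here are hypothetical; yieldEvery is the field added to ModelFitConfig in this release):

import * as tf from '@tensorflow/tfjs';

async function train(model: tf.Model, xs: tf.Tensor, ys: tf.Tensor) {
  // 'auto' is the new default: measure batch duration, then yield to the
  // event loop about every 16 ms of work. 'batch' | 'epoch' | 'never'
  // remain available for explicit control.
  await model.fit(xs, ys, {epochs: 5, yieldEvery: 'auto'});
}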
engine/container.d.ts
import { Scalar, serialization, Tensor } from '@tensorflow/tfjs-core';
import { JsonDict, Kwargs, NamedTensorMap, Shape } from '../types';
import { LayerVariable } from '../variables';
import { Layer, Node, SymbolicTensor } from './topology';
import { Layer, Node, SymbolicTensor, DisposeResult } from './topology';
export declare function loadWeightsFromJson(weightsJSON: JsonDict, layers: Layer[], skipMismatch?: boolean): void;
export declare function loadWeightsFromNamedTensorMap(weights: NamedTensorMap, layers: Layer[]): void;
export declare function loadWeightsFromNamedTensorMap(weights: NamedTensorMap, layers: Layer[], strict?: boolean): void;
export interface ContainerConfig {
@@ -37,6 +37,8 @@ inputs: SymbolicTensor | SymbolicTensor[];
    constructor(config: ContainerConfig);
    protected assertNotDisposed(): void;
    dispose(): DisposeResult;
    readonly trainableWeights: LayerVariable[];
    readonly nonTrainableWeights: LayerVariable[];
    readonly weights: LayerVariable[];
    loadWeights(weightsJSON: JsonDict | NamedTensorMap, skipMismatch?: boolean, isNamedTensorMap?: boolean): void;
    loadWeights(weightsJSON: JsonDict | NamedTensorMap, skipMismatch?: boolean, isNamedTensorMap?: boolean, strict?: boolean): void;
    private updatedConfig();
@@ -43,0 +45,0 @@ toJSON(unused?: any, returnString?: boolean): string | JsonDict;
@@ -95,3 +95,4 @@ "use strict"; | ||
exports.loadWeightsFromJson = loadWeightsFromJson; | ||
function loadWeightsFromNamedTensorMap(weights, layers) { | ||
function loadWeightsFromNamedTensorMap(weights, layers, strict) { | ||
if (strict === void 0) { strict = true; } | ||
var nameToWeight = {}; | ||
@@ -112,13 +113,20 @@ var totalWeightsCount = 0; | ||
for (var name_2 in weights) { | ||
weightValueTuples.push([nameToWeight[name_2], weights[name_2]]); | ||
if (nameToWeight[name_2] != null) { | ||
weightValueTuples.push([nameToWeight[name_2], weights[name_2]]); | ||
} | ||
else if (strict) { | ||
throw new errors_1.ValueError("Provided weight data has no target variable: " + name_2); | ||
} | ||
delete nameToWeight[name_2]; | ||
} | ||
var unsetNames = []; | ||
for (var name_3 in nameToWeight) { | ||
unsetNames.push(name_3); | ||
if (strict) { | ||
var unsetNames = []; | ||
for (var name_3 in nameToWeight) { | ||
unsetNames.push(name_3); | ||
} | ||
if (unsetNames.length > 0) { | ||
throw new errors_1.ValueError(unsetNames.length + " of " + totalWeightsCount + " weights are not set: " + | ||
("" + unsetNames)); | ||
} | ||
} | ||
if (unsetNames.length > 0) { | ||
throw new errors_1.ValueError(unsetNames.length + " of " + totalWeightsCount + " weights are not set: " + | ||
("" + unsetNames)); | ||
} | ||
variables_1.batchSetValue(weightValueTuples); | ||
@@ -381,4 +389,25 @@ } | ||
_this.built = true; | ||
_this._refCount = 1; | ||
return _this; | ||
} | ||
Container.prototype.assertNotDisposed = function () { | ||
if (this._refCount === 0) { | ||
throw new Error("Container '" + this.name + "' is already disposed."); | ||
} | ||
}; | ||
Container.prototype.dispose = function () { | ||
this.assertNotDisposed(); | ||
var result = { | ||
refCountAfterDispose: null, | ||
numDisposedVariables: 0 | ||
}; | ||
if (--this._refCount === 0) { | ||
for (var _i = 0, _a = this.layers; _i < _a.length; _i++) { | ||
var layer = _a[_i]; | ||
result.numDisposedVariables += layer.dispose().numDisposedVariables; | ||
} | ||
} | ||
result.refCountAfterDispose = this._refCount; | ||
return result; | ||
}; | ||
Object.defineProperty(Container.prototype, "trainableWeights", { | ||
@@ -432,7 +461,8 @@ get: function () { | ||
}); | ||
Container.prototype.loadWeights = function (weightsJSON, skipMismatch, isNamedTensorMap) { | ||
Container.prototype.loadWeights = function (weightsJSON, skipMismatch, isNamedTensorMap, strict) { | ||
if (skipMismatch === void 0) { skipMismatch = false; } | ||
if (isNamedTensorMap === void 0) { isNamedTensorMap = false; } | ||
if (strict === void 0) { strict = true; } | ||
if (isNamedTensorMap) { | ||
loadWeightsFromNamedTensorMap(weightsJSON, this.layers); | ||
loadWeightsFromNamedTensorMap(weightsJSON, this.layers, strict); | ||
} | ||
@@ -439,0 +469,0 @@ else { |
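
These Container changes add reference-counted disposal to whole models. A sketch of the intended usage, assuming the 0.7.3 API shown in the diff (the DisposeResult fields come from engine/topology.d.ts below):

import * as tf from '@tensorflow/tfjs';

const input = tf.input({shape: [4]});
const output = tf.layers.dense({units: 1}).apply(input) as tf.SymbolicTensor;
const model = tf.model({inputs: input, outputs: output});

// dispose() decrements the container's reference count; once it reaches
// zero, each layer is disposed in turn and its weights are freed.
const result = model.dispose();
console.log(result.refCountAfterDispose);   // 0
console.log(result.numDisposedVariables);   // 2 (dense kernel + bias)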
engine/input_layer.d.ts
import { DataType, serialization, Tensor } from '@tensorflow/tfjs-core';
import { Kwargs, Shape } from '../types';
import { Layer, SymbolicTensor } from './topology';
import { Layer, SymbolicTensor, DisposeResult } from './topology';
export interface InputLayerConfig {
@@ -17,2 +17,3 @@ inputShape?: Shape;
    apply(inputs: Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[], kwargs?: Kwargs): Tensor | Tensor[] | SymbolicTensor;
    dispose(): DisposeResult;
    getConfig(): serialization.ConfigDict;
@@ -19,0 +20,0 @@ }
@@ -78,2 +78,8 @@ "use strict"; | ||
}; | ||
InputLayer.prototype.dispose = function () { | ||
return { | ||
refCountAfterDispose: this._refCount, | ||
numDisposedVariables: 0 | ||
}; | ||
}; | ||
InputLayer.prototype.getConfig = function () { | ||
@@ -80,0 +86,0 @@ return { |
engine/topology.d.ts
@@ -56,2 +56,6 @@ import { DataType, Scalar, serialization, Tensor } from '@tensorflow/tfjs-core';
}
export interface DisposeResult {
    refCountAfterDispose: number;
    numDisposedVariables: number;
}
export declare class Node {
@@ -106,2 +110,3 @@ callArgs: Kwargs;
    protected _stateful: boolean;
    protected _refCount: number | null;
    constructor(config: LayerConfig);
@@ -140,3 +145,6 @@ protected static nodeKey(layer: Layer, nodeIndex: number): string;
    getConfig(): serialization.ConfigDict;
    protected disposeWeights(): number;
    protected assertNotDisposed(): void;
    dispose(): DisposeResult;
}
export declare function getSourceInputs(tensor: SymbolicTensor, layer?: Layer, nodeIndex?: number): SymbolicTensor[];
@@ -155,2 +155,3 @@ "use strict"; | ||
} | ||
_this._refCount = null; | ||
return _this; | ||
@@ -372,2 +373,3 @@ } | ||
kwargs = kwargs || {}; | ||
this.assertNotDisposed(); | ||
var inputsList = generic_utils.toList(inputs); | ||
@@ -407,2 +409,5 @@ var allAreSymbolic = true; | ||
} | ||
if (_this._refCount === null && noneAreSymbolic) { | ||
_this._refCount = 1; | ||
} | ||
} | ||
@@ -444,2 +449,3 @@ _this.assertInputCompatibility(inputs); | ||
_this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs); | ||
_this._refCount++; | ||
if (_this.activityRegularizer != null) { | ||
@@ -657,2 +663,30 @@ throw new errors_1.NotImplementedError('Layer invocation in the presence of activity ' + | ||
}; | ||
Layer.prototype.disposeWeights = function () { | ||
this.weights.forEach(function (weight) { return weight.dispose(); }); | ||
return this.weights.length; | ||
}; | ||
Layer.prototype.assertNotDisposed = function () { | ||
if (this._refCount === 0) { | ||
throw new Error("Layer '" + this.name + "' is already disposed."); | ||
} | ||
}; | ||
Layer.prototype.dispose = function () { | ||
if (!this.built) { | ||
throw new Error("Cannot dispose Layer " + this.name + " because it has not been " + | ||
"built yet."); | ||
} | ||
if (this._refCount === null) { | ||
throw new Error("Cannot dispose Layer " + this.name + " because it has not been used " + | ||
"yet."); | ||
} | ||
this.assertNotDisposed(); | ||
var numDisposedVariables = 0; | ||
if (--this._refCount === 0) { | ||
numDisposedVariables = this.disposeWeights(); | ||
} | ||
return { | ||
refCountAfterDispose: this._refCount, | ||
numDisposedVariables: numDisposedVariables | ||
}; | ||
}; | ||
return Layer; | ||
@@ -659,0 +693,0 @@ }(tfjs_core_1.serialization.Serializable)); |
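
At the layer level, _refCount counts how many times a layer has been applied, so a layer shared across several graphs is only truly freed by its final dispose(). A sketch under that reading of the diff:

import * as tf from '@tensorflow/tfjs';

const shared = tf.layers.dense({units: 2});
shared.apply(tf.input({shape: [3]}));  // refCount -> 1
shared.apply(tf.input({shape: [3]}));  // refCount -> 2 (shared use)

let r = shared.dispose();
console.log(r.refCountAfterDispose, r.numDisposedVariables);  // 1 0
r = shared.dispose();
console.log(r.refCountAfterDispose, r.numDisposedVariables);  // 0 2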
engine/training.d.ts
import * as tfc from '@tensorflow/tfjs-core';
import { io, ModelPredictConfig, Optimizer, Scalar, Tensor, Tensor1D } from '@tensorflow/tfjs-core';
import { BaseCallback, CustomCallbackConfig, History } from '../base_callbacks';
import { BaseCallback, CustomCallbackConfig, History, YieldEveryOptions } from '../base_callbacks';
import { LossOrMetricFn, NamedTensorMap, Shape } from '../types';
@@ -48,3 +48,3 @@ import { Container, ContainerConfig } from './container';
    validationSteps?: number;
    yieldEvery?: 'batch' | 'epoch' | 'never';
    yieldEvery?: YieldEveryOptions;
}
@@ -98,3 +98,3 @@ export interface ModelCompileConfig {
    }, checkBatchAxis?: boolean, batchSize?: number): [Tensor[], Tensor[], Tensor[]];
    private fitLoop(f, ins, outLabels?, batchSize?, epochs?, verbose?, callbacks?, valF?, valIns?, shuffle?, callbackMetrics?, initialEpoch?, stepsPerEpoch?, validationSteps?);
    private fitLoop(f, ins, outLabels?, batchSize?, epochs?, verbose?, callbacks?, valF?, valIns?, shuffle?, callbackMetrics?, initialEpoch?, stepsPerEpoch?, validationSteps?, yieldEvery?);
    private testLoop(f, ins, batchSize?, verbose?, steps?);
@@ -101,0 +101,0 @@ private getDedupedMetricsNames();
@@ -720,6 +720,6 @@ "use strict"; | ||
}; | ||
Model.prototype.fitLoop = function (f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) { | ||
Model.prototype.fitLoop = function (f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps, yieldEvery) { | ||
return __awaiter(this, void 0, void 0, function () { | ||
var _this = this; | ||
var doValidation, numTrainSamples, indexArray, callbackList, _loop_4, this_1, epoch, state_2; | ||
var doValidation, numTrainSamples, indexArray, baseLogger, callbackList, _loop_4, this_1, epoch, state_2; | ||
return __generator(this, function (_a) { | ||
@@ -756,7 +756,8 @@ switch (_a.label) { | ||
this.history = new base_callbacks_1.History(); | ||
baseLogger = new base_callbacks_1.BaseLogger(yieldEvery); | ||
if (callbacks == null) { | ||
callbacks = [new base_callbacks_1.BaseLogger()]; | ||
callbacks = [baseLogger]; | ||
} | ||
else { | ||
callbacks = [new base_callbacks_1.BaseLogger()].concat(callbacks); | ||
callbacks = [baseLogger].concat(callbacks); | ||
} | ||
@@ -1093,3 +1094,3 @@ callbacks = callbacks.concat([this.history]); | ||
callbacks = base_callbacks_1.standardizeCallbacks(config.callbacks); | ||
return [4, this.fitLoop(trainFunction, ins, outLabels, batchSize, config.epochs, config.verbose, callbacks, valFunction, valIns, config.shuffle, callbackMetrics, config.initialEpoch, null, null)]; | ||
return [4, this.fitLoop(trainFunction, ins, outLabels, batchSize, config.epochs, config.verbose, callbacks, valFunction, valIns, config.shuffle, callbackMetrics, config.initialEpoch, null, null, config.yieldEvery)]; | ||
case 1: | ||
@@ -1096,0 +1097,0 @@ out = _a.sent(); |
exports_layers.d.ts
@@ -7,3 +7,3 @@ import { InputLayerConfig } from './engine/input_layer';
import { DepthwiseConv2DLayerConfig } from './layers/convolutional_depthwise';
import { ActivationLayerConfig, DenseLayerConfig, DropoutLayerConfig, RepeatVectorLayerConfig, ReshapeLayerConfig } from './layers/core';
import { ActivationLayerConfig, DenseLayerConfig, DropoutLayerConfig, PermuteLayerConfig, RepeatVectorLayerConfig, ReshapeLayerConfig } from './layers/core';
import { EmbeddingLayerConfig } from './layers/embeddings';
@@ -14,3 +14,3 @@ import { ConcatenateLayerConfig } from './layers/merge';
import { GlobalPooling2DLayerConfig, Pooling1DLayerConfig, Pooling2DLayerConfig } from './layers/pooling';
import { GRUCellLayerConfig, GRULayerConfig, LSTMCellLayerConfig, LSTMLayerConfig, RNNCell, RNNLayerConfig, SimpleRNNCellLayerConfig, SimpleRNNLayerConfig, StackedRNNCellsConfig } from './layers/recurrent';
import { GRUCellLayerConfig, GRULayerConfig, LSTMCellLayerConfig, LSTMLayerConfig, RNN, RNNCell, RNNLayerConfig, SimpleRNNCellLayerConfig, SimpleRNNLayerConfig, StackedRNNCellsConfig } from './layers/recurrent';
import { BidirectionalLayerConfig, Wrapper, WrapperLayerConfig } from './layers/wrappers';
@@ -35,2 +35,3 @@ export declare function inputLayer(config: InputLayerConfig): Layer;
export declare function reshape(config: ReshapeLayerConfig): Layer;
export declare function permute(config: PermuteLayerConfig): Layer;
export declare function embedding(config: EmbeddingLayerConfig): Layer;
@@ -71,2 +72,2 @@ export declare function add(config?: LayerConfig): Layer;
export declare const maxPool2d: typeof maxPooling2d;
export { Layer, input };
export { Layer, RNN, RNNCell, input };
@@ -18,2 +18,4 @@ "use strict"; | ||
var recurrent_1 = require("./layers/recurrent"); | ||
exports.RNN = recurrent_1.RNN; | ||
exports.RNNCell = recurrent_1.RNNCell; | ||
var wrappers_1 = require("./layers/wrappers"); | ||
@@ -92,2 +94,6 @@ function inputLayer(config) { | ||
exports.reshape = reshape; | ||
function permute(config) { | ||
return new core_1.Permute(config); | ||
} | ||
exports.permute = permute; | ||
function embedding(config) { | ||
@@ -94,0 +100,0 @@ return new embeddings_1.Embedding(config); |
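
Re-exporting RNN and RNNCell makes it possible to assemble recurrent layers from cells directly. A hedged sketch (tf.layers.RNN and tf.layers.lstmCell as exposed through the union package; exact entry points may differ in your setup):

import * as tf from '@tensorflow/tfjs';

// Two stacked LSTM cells wrapped in one recurrent layer; RNN accepts a
// single cell or an array of cells (stacked internally).
const rnn = new tf.layers.RNN({
  cell: [tf.layers.lstmCell({units: 8}), tf.layers.lstmCell({units: 8})],
  returnSequences: true,
});
const out = rnn.apply(tf.input({shape: [10, 4]})) as tf.SymbolicTensor;
console.log(out.shape);  // [null, 10, 8]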
exports.d.ts
@@ -9,3 +9,3 @@ import { io } from '@tensorflow/tfjs-core';
export declare function sequential(config?: SequentialConfig): Sequential;
export declare function loadModel(pathOrIOHandler: string | io.IOHandler): Promise<Model>;
export declare function loadModel(pathOrIOHandler: string | io.IOHandler, strict?: boolean): Promise<Model>;
export declare function input(config: InputConfig): SymbolicTensor;
@@ -14,4 +14,5 @@ "use strict"; | ||
exports.sequential = sequential; | ||
function loadModel(pathOrIOHandler) { | ||
return models_1.loadModelInternal(pathOrIOHandler); | ||
function loadModel(pathOrIOHandler, strict) { | ||
if (strict === void 0) { strict = true; } | ||
return models_1.loadModelInternal(pathOrIOHandler, strict); | ||
} | ||
@@ -18,0 +19,0 @@ exports.loadModel = loadModel; |
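
The strict flag propagates from loadModel down to loadWeightsFromNamedTensorMap. With strict = false, missing or extra weights are tolerated instead of throwing, which helps when loading saved weights into a slightly modified topology. Sketch (the URL is a placeholder):

import * as tf from '@tensorflow/tfjs';

async function loadLenient(): Promise<tf.Model> {
  // strict defaults to true; passing false skips the "weights are not set"
  // and "no target variable" errors shown in engine/container.js above.
  return tf.loadModel('https://example.com/model.json', false);
}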
index.d.ts
@@ -6,3 +6,3 @@ import * as constraints from './exports_constraints';
import * as regularizers from './exports_regularizers';
export { CallbackList, CustomCallback, CustomCallbackConfig } from './base_callbacks';
export { CallbackList, CustomCallback, CustomCallbackConfig, History } from './base_callbacks';
export { Callback } from './callbacks';
@@ -9,0 +9,0 @@ export { SymbolicTensor } from './engine/topology';
@@ -16,2 +16,3 @@ "use strict"; | ||
exports.CustomCallback = base_callbacks_1.CustomCallback; | ||
exports.History = base_callbacks_1.History; | ||
var callbacks_1 = require("./callbacks"); | ||
@@ -18,0 +19,0 @@ exports.Callback = callbacks_1.Callback; |
layers/core.d.ts
@@ -98,1 +98,13 @@ import { serialization, Tensor } from '@tensorflow/tfjs-core';
}
export interface PermuteLayerConfig extends LayerConfig {
    dims: number[];
}
export declare class Permute extends Layer {
    static className: string;
    readonly dims: number[];
    private readonly dimsIncludingBatch;
    constructor(config: PermuteLayerConfig);
    computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[];
    call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
    getConfig(): serialization.ConfigDict;
}
@@ -15,10 +15,10 @@ "use strict"; | ||
var activations_1 = require("../activations"); | ||
var state_1 = require("../backend/state"); | ||
var K = require("../backend/tfjs_backend"); | ||
var constraints_1 = require("../constraints"); | ||
var topology_1 = require("../engine/topology"); | ||
var state_1 = require("../backend/state"); | ||
var errors_1 = require("../errors"); | ||
var initializers_1 = require("../initializers"); | ||
var regularizers_1 = require("../regularizers"); | ||
var math_utils = require("../utils/math_utils"); | ||
var math_utils_1 = require("../utils/math_utils"); | ||
var types_utils_1 = require("../utils/types_utils"); | ||
@@ -193,3 +193,3 @@ var Dropout = (function (_super) { | ||
} | ||
return [inputShape[0], math_utils.arrayProd(inputShape, 1)]; | ||
return [inputShape[0], math_utils_1.arrayProd(inputShape, 1)]; | ||
}; | ||
@@ -300,3 +300,3 @@ Flatten.prototype.call = function (inputs, kwargs) { | ||
} | ||
var originalSize = math_utils.arrayProd(inputShape); | ||
var originalSize = math_utils_1.arrayProd(inputShape); | ||
if (unknown !== null) { | ||
@@ -351,2 +351,48 @@ if (known === 0 || originalSize % known !== 0) { | ||
tfjs_core_1.serialization.SerializationMap.register(Reshape); | ||
var Permute = (function (_super) { | ||
__extends(Permute, _super); | ||
function Permute(config) { | ||
var _this = _super.call(this, config) || this; | ||
if (config.dims == null) { | ||
throw new Error('Required configuration field `dims` is missing during Permute ' + | ||
'constructor call.'); | ||
} | ||
if (!Array.isArray(config.dims)) { | ||
throw new Error('Permute constructor requires `dims` to be an Array, but received ' + | ||
(config.dims + " instead.")); | ||
} | ||
var expectedSortedIndices = math_utils_1.range(1, config.dims.length + 1); | ||
if (!tfjs_core_1.util.arraysEqual(config.dims.slice().sort(), expectedSortedIndices)) { | ||
throw new Error('Invalid permutation `dims`: ' + JSON.stringify(config.dims) + | ||
' `dims` must contain consecutive integers starting from 1.'); | ||
} | ||
_this.dims = config.dims; | ||
_this.dimsIncludingBatch = [0].concat(_this.dims); | ||
_this.inputSpec = [new topology_1.InputSpec({ ndim: _this.dims.length + 1 })]; | ||
return _this; | ||
} | ||
Permute.prototype.computeOutputShape = function (inputShape) { | ||
inputShape = types_utils_1.getExactlyOneShape(inputShape); | ||
var outputShape = inputShape.slice(); | ||
this.dims.forEach(function (dim, i) { | ||
outputShape[i + 1] = inputShape[dim]; | ||
}); | ||
return outputShape; | ||
}; | ||
Permute.prototype.call = function (inputs, kwargs) { | ||
return tfjs_core_1.transpose(types_utils_1.getExactlyOneTensor(inputs), this.dimsIncludingBatch); | ||
}; | ||
Permute.prototype.getConfig = function () { | ||
var config = { | ||
dims: this.dims, | ||
}; | ||
var baseConfig = _super.prototype.getConfig.call(this); | ||
Object.assign(config, baseConfig); | ||
return config; | ||
}; | ||
Permute.className = 'Permute'; | ||
return Permute; | ||
}(topology_1.Layer)); | ||
exports.Permute = Permute; | ||
tfjs_core_1.serialization.SerializationMap.register(Permute); | ||
//# sourceMappingURL=core.js.map |
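
The new Permute layer transposes the non-batch axes of its input (axis 0, the batch dimension, always stays first). A small sketch of how it composes, using the permute factory added in exports_layers above:

import * as tf from '@tensorflow/tfjs';

// dims are 1-indexed over the non-batch axes: [2, 1] swaps them, so an
// input of shape [batch, 10, 64] comes out as [batch, 64, 10].
const x = tf.input({shape: [10, 64]});
const y = tf.layers.permute({dims: [2, 1]}).apply(x) as tf.SymbolicTensor;
console.log(y.shape);  // [null, 64, 10]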
models.d.ts
@@ -17,4 +17,4 @@ import { io, Scalar, serialization, Tensor } from '@tensorflow/tfjs-core';
}
export declare function loadModelInternal(pathOrIOHandler: string | io.IOHandler): Promise<Model>;
export declare function loadModelFromIOHandler(handler: io.IOHandler, customObjects?: serialization.ConfigDict): Promise<Model>;
export declare function loadModelInternal(pathOrIOHandler: string | io.IOHandler, strict?: boolean): Promise<Model>;
export declare function loadModelFromIOHandler(handler: io.IOHandler, customObjects?: serialization.ConfigDict, strict?: boolean): Promise<Model>;
export interface SequentialConfig {
@@ -21,0 +21,0 @@ layers?: Layer[];
@@ -90,3 +90,4 @@ "use strict"; | ||
exports.modelFromJSON = modelFromJSON; | ||
function loadModelInternal(pathOrIOHandler) { | ||
function loadModelInternal(pathOrIOHandler, strict) { | ||
if (strict === void 0) { strict = true; } | ||
return __awaiter(this, void 0, void 0, function () { | ||
@@ -106,3 +107,3 @@ var handlers; | ||
} | ||
return [2, loadModelFromIOHandler(pathOrIOHandler)]; | ||
return [2, loadModelFromIOHandler(pathOrIOHandler, undefined, strict)]; | ||
}); | ||
@@ -112,3 +113,4 @@ }); | ||
exports.loadModelInternal = loadModelInternal; | ||
function loadModelFromIOHandler(handler, customObjects) { | ||
function loadModelFromIOHandler(handler, customObjects, strict) { | ||
if (strict === void 0) { strict = true; } | ||
return __awaiter(this, void 0, void 0, function () { | ||
@@ -138,3 +140,3 @@ var artifacts, modelTopology, model, skipMismatch, isNamedTensorMap; | ||
isNamedTensorMap = true; | ||
model.loadWeights(tfjs_core_1.io.decodeWeights(artifacts.weightData, artifacts.weightSpecs), skipMismatch, isNamedTensorMap); | ||
model.loadWeights(tfjs_core_1.io.decodeWeights(artifacts.weightData, artifacts.weightSpecs), skipMismatch, isNamedTensorMap, strict); | ||
} | ||
@@ -141,0 +143,0 @@ return [2, model]; |
@@ -41,2 +41,11 @@ "use strict"; | ||
model.checkTrainableWeightsConsistency(); | ||
var trainableCount = countTrainableParams(model); | ||
var nonTrainableCount = variable_utils_1.countParamsInWeights(model.nonTrainableWeights); | ||
printFn("Total params: " + (trainableCount + nonTrainableCount)); | ||
printFn("Trainable params: " + trainableCount); | ||
printFn("Non-trainable params: " + nonTrainableCount); | ||
printFn('_'.repeat(lineLength)); | ||
} | ||
exports.printSummary = printSummary; | ||
function countTrainableParams(model) { | ||
var trainableCount; | ||
@@ -50,9 +59,4 @@ if (model.collectedTrainableWeights != null) { | ||
} | ||
var nonTrainableCount = variable_utils_1.countParamsInWeights(model.nonTrainableWeights); | ||
printFn("Total params: " + (trainableCount + nonTrainableCount)); | ||
printFn("Trainable params: " + trainableCount); | ||
printFn("Non-trainable params: " + nonTrainableCount); | ||
printFn('_'.repeat(lineLength)); | ||
return trainableCount; | ||
} | ||
exports.printSummary = printSummary; | ||
function isModelSequentialLike(model) { | ||
@@ -59,0 +63,0 @@ var sequentialLike = true; |
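
This one is a pure refactor: the parameter counting moves into a countTrainableParams helper, and the printed summary is unchanged. For reference, a sketch of the footer it prints, assuming model.summary() is wired to printSummary as in current tfjs (counts are for a hypothetical one-dense-layer model):

import * as tf from '@tensorflow/tfjs';

const input = tf.input({shape: [4]});
const output = tf.layers.dense({units: 1}).apply(input) as tf.SymbolicTensor;
const model = tf.model({inputs: input, outputs: output});
model.summary();
// ... layer table ...
// Total params: 5            <- 4 kernel weights + 1 bias
// Trainable params: 5
// Non-trainable params: 0
// _________________________________________________________________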
variables.d.ts
@@ -17,2 +17,4 @@ import * as tfc from '@tensorflow/tfjs-core';
    write(newVal: Tensor): this;
    dispose(): void;
    protected assertNotDisposed(): void;
}
@@ -19,0 +21,0 @@ export declare function variable(x: Tensor, dtype?: DataType, name?: string, constraint?: Constraint): LayerVariable;
@@ -26,5 +26,7 @@ "use strict"; | ||
LayerVariable.prototype.read = function () { | ||
this.assertNotDisposed(); | ||
return this.val; | ||
}; | ||
LayerVariable.prototype.write = function (newVal) { | ||
this.assertNotDisposed(); | ||
checkShapesMatch(this.val, newVal); | ||
@@ -37,2 +39,11 @@ this.val.assign(newVal); | ||
}; | ||
LayerVariable.prototype.dispose = function () { | ||
this.assertNotDisposed(); | ||
this.val.dispose(); | ||
}; | ||
LayerVariable.prototype.assertNotDisposed = function () { | ||
if (this.val.isDisposed) { | ||
throw new Error("LayersVariable " + this.name + " is already disposed."); | ||
} | ||
}; | ||
return LayerVariable; | ||
@@ -39,0 +50,0 @@ }()); |
version.d.ts
@@ -1,2 +0,2 @@
declare const version = "0.7.2";
declare const version = "0.7.3";
export { version };
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var version = '0.7.2'; | ||
var version = '0.7.3'; | ||
exports.version = version; | ||
//# sourceMappingURL=version.js.map |
package.json
{
  "name": "@tensorflow/tfjs-layers",
  "version": "0.7.2",
  "version": "0.7.3",
  "description": "TensorFlow layers API in JavaScript",
@@ -13,3 +13,3 @@ "private": false,
  "devDependencies": {
    "@tensorflow/tfjs-core": "~0.12.8",
    "@tensorflow/tfjs-core": "~0.12.10",
    "@types/jasmine": "~2.5.53",
@@ -31,3 +31,3 @@ "clang-format": "~1.2.2",
    "rollup-plugin-uglify": "~3.0.0",
    "tslint": "~5.6.0",
    "tslint": "~5.11.0",
    "tslint-no-circular-imports": "^0.5.0",
@@ -47,7 +47,7 @@ "typescript": "2.8.3",
    "test-travis": "karma start --browsers='bs_firefox_mac,bs_chrome_mac' --singleRun --reporters='dots,karma-typescript'",
    "lint": "tslint -p . --type-check -t verbose"
    "lint": "tslint -p . -t verbose"
  },
  "peerDependencies": {
    "@tensorflow/tfjs-core": "~0.12.8"
    "@tensorflow/tfjs-core": "~0.12.10"
  }
}