@tensorflow/tfjs-layers
Comparing version 0.0.4 to 0.0.5

"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var common_1 = require("../common"); | ||
var _epsilon = 1e-7; | ||
@@ -14,4 +13,4 @@ function epsilon() { | ||
function imageDataFormat() { | ||
return common_1.DataFormat.CHANNEL_LAST; | ||
return 'channelLast'; | ||
} | ||
exports.imageDataFormat = imageDataFormat; |
@@ -93,7 +93,7 @@ import { Scalar, Tensor, Tensor1D } from '@tensorflow/tfjs-core';
export declare function l2Normalize(x: Tensor, axis?: number): Tensor;
- export declare function conv1dWithBias(x: Tensor, kernel: Tensor, bias: Tensor, strides?: number, padding?: PaddingMode, dataFormat?: DataFormat, dilationRate?: number): Tensor;
- export declare function conv1d(x: Tensor, kernel: Tensor, strides?: number, padding?: PaddingMode, dataFormat?: DataFormat, dilationRate?: number): Tensor;
- export declare function conv2d(x: Tensor, kernel: Tensor, strides?: number[], padding?: PaddingMode, dataFormat?: DataFormat, dilationRate?: [number, number]): Tensor;
- export declare function conv2dWithBias(x: Tensor, kernel: Tensor, bias: Tensor, strides?: number[], padding?: PaddingMode, dataFormat?: DataFormat, dilationRate?: [number, number]): Tensor;
- export declare function depthwiseConv2d(x: Tensor, depthwiseKernel: Tensor, strides?: [number, number], padding?: PaddingMode, dataFormat?: DataFormat, dilationRate?: [number, number]): Tensor;
+ export declare function conv1dWithBias(x: Tensor, kernel: Tensor, bias: Tensor, strides?: number, padding?: string, dataFormat?: DataFormat, dilationRate?: number): Tensor;
+ export declare function conv1d(x: Tensor, kernel: Tensor, strides?: number, padding?: string, dataFormat?: DataFormat, dilationRate?: number): Tensor;
+ export declare function conv2d(x: Tensor, kernel: Tensor, strides?: number[], padding?: string, dataFormat?: DataFormat, dilationRate?: [number, number]): Tensor;
+ export declare function conv2dWithBias(x: Tensor, kernel: Tensor, bias: Tensor, strides?: number[], padding?: string, dataFormat?: DataFormat, dilationRate?: [number, number]): Tensor;
+ export declare function depthwiseConv2d(x: Tensor, depthwiseKernel: Tensor, strides?: [number, number], padding?: string, dataFormat?: DataFormat, dilationRate?: [number, number]): Tensor;
export declare function pool2d(x: Tensor, poolSize: [number, number], strides?: [number, number], padding?: PaddingMode, dataFormat?: DataFormat, poolMode?: PoolMode): Tensor;
@@ -100,0 +100,0 @@ export declare function nameScope<T>(name: string, fn: () => T): T;

@@ -512,2 +512,3 @@ "use strict";
function biasAdd(x, bias, dataFormat) {
+ common_1.checkDataFormat(dataFormat);
if (ndim(bias) !== 1 && ndim(bias) !== ndim(x)) {
@@ -573,3 +574,4 @@ throw new errors_1.ValueError('Unexpected bias dimensions: ' + ndim(bias) +
function preprocessConv2DInput(x, dataFormat) {
- if (dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ common_1.checkDataFormat(dataFormat);
+ if (dataFormat === 'channelFirst') {
return tfc.transpose(x, [0, 2, 3, 1]);
@@ -583,3 +585,3 @@ }
if (strides === void 0) { strides = 1; }
- if (padding === void 0) { padding = common_1.PaddingMode.VALID; }
+ if (padding === void 0) { padding = 'valid'; }
if (dilationRate === void 0) { dilationRate = 1; }
@@ -589,2 +591,3 @@ if (dataFormat == null) {
}
+ common_1.checkDataFormat(dataFormat);
if (dilationRate !== 1) {
@@ -606,10 +609,10 @@ throw new errors_1.NotImplementedError("dilationRate = " + dilationRate + " is not implemented for 1D " +
}
- if (dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ if (dataFormat === 'channelFirst') {
x = transpose(x, [0, 2, 1]);
}
- if (padding === common_1.PaddingMode.CASUAL) {
+ if (padding === 'casual') {
throw new errors_1.NotImplementedError('The support for CASUAL padding mode in conv1dWithBias is not ' +
'implemented yet.');
}
- var y = tfc.conv1d(x, kernel, strides, padding === common_1.PaddingMode.SAME ? 'same' : 'valid');
+ var y = tfc.conv1d(x, kernel, strides, padding === 'same' ? 'same' : 'valid');
if (bias != null) {
@@ -623,4 +626,5 @@ y = biasAdd(y, bias);
if (strides === void 0) { strides = 1; }
- if (padding === void 0) { padding = common_1.PaddingMode.VALID; }
+ if (padding === void 0) { padding = 'valid'; }
if (dilationRate === void 0) { dilationRate = 1; }
+ common_1.checkDataFormat(dataFormat);
return conv1dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
@@ -631,3 +635,4 @@ }
if (strides === void 0) { strides = [1, 1]; }
- if (padding === void 0) { padding = common_1.PaddingMode.VALID; }
+ if (padding === void 0) { padding = 'valid'; }
+ common_1.checkDataFormat(dataFormat);
return conv2dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
@@ -638,6 +643,7 @@ }
if (strides === void 0) { strides = [1, 1]; }
- if (padding === void 0) { padding = common_1.PaddingMode.VALID; }
+ if (padding === void 0) { padding = 'valid'; }
if (dataFormat == null) {
dataFormat = common_3.imageDataFormat();
}
+ common_1.checkDataFormat(dataFormat);
if (dilationRate != null) {
@@ -655,11 +661,11 @@ throw new errors_1.NotImplementedError('Support for non-default dilation rate is not implemented yet.');
var y = preprocessConv2DInput(x, dataFormat);
- if (padding === common_1.PaddingMode.CASUAL) {
+ if (padding === 'casual') {
throw new errors_1.NotImplementedError('The support for CASUAL padding mode in conv1dWithBias is not ' +
'implemented yet.');
}
- y = tfc.conv2d(y, kernel, strides, padding === common_1.PaddingMode.SAME ? 'same' : 'valid');
+ y = tfc.conv2d(y, kernel, strides, padding === 'same' ? 'same' : 'valid');
if (bias != null) {
y = biasAdd(y, bias);
}
- if (dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ if (dataFormat === 'channelFirst') {
y = tfc.transpose(y, [0, 3, 1, 2]);
@@ -672,6 +678,7 @@ }
if (strides === void 0) { strides = [1, 1]; }
- if (padding === void 0) { padding = common_1.PaddingMode.VALID; }
+ if (padding === void 0) { padding = 'valid'; }
if (dataFormat == null) {
dataFormat = common_3.imageDataFormat();
}
+ common_1.checkDataFormat(dataFormat);
var y = preprocessConv2DInput(x, dataFormat);
@@ -686,4 +693,4 @@ if (ndim(x) !== 4) {
}
- y = tfc.depthwiseConv2d(y, depthwiseKernel, strides, padding === common_1.PaddingMode.SAME ? 'same' : 'valid', dilationRate);
- if (dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ y = tfc.depthwiseConv2d(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', dilationRate);
+ if (dataFormat === 'channelFirst') {
y = tfc.transpose(y, [0, 3, 1, 2]);
@@ -695,2 +702,5 @@ }
function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) {
+ common_1.checkDataFormat(dataFormat);
+ common_1.checkPoolMode(poolMode);
+ common_1.checkPaddingMode(padding);
if (strides == null) {
@@ -700,3 +710,3 @@ strides = [1, 1];
if (padding == null) {
- padding = common_1.PaddingMode.VALID;
+ padding = 'valid';
}
@@ -707,8 +717,8 @@ if (dataFormat == null) {
if (poolMode == null) {
- poolMode = common_1.PoolMode.MAX;
+ poolMode = 'max';
}
x = preprocessConv2DInput(x, dataFormat);
var y;
- var paddingString = (padding === common_1.PaddingMode.SAME) ? 'same' : 'valid';
- if (poolMode === common_1.PoolMode.MAX) {
+ var paddingString = (padding === 'same') ? 'same' : 'valid';
+ if (poolMode === 'max') {
y = tfc.maxPool(x, poolSize, strides, paddingString);
@@ -719,3 +729,3 @@ }
}
- if (dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ if (dataFormat === 'channelFirst') {
y = tfc.transpose(y, [0, 3, 1, 2]);
@@ -722,0 +732,0 @@ }

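Taken together, these backend hunks replace enum comparisons with plain string comparisons and add an up-front check* call before dispatch. A self-contained sketch of that validate-then-dispatch shape, illustrative only (the real pool2d calls into tfc.maxPool/tfc.avgPool, which are stubbed out here):

```ts
// Illustrative sketch of the 0.0.5 control flow in pool2d; not library code.
type PaddingMode = 'valid' | 'same' | 'casual';
type PoolMode = 'max' | 'avg';

function pool2dDispatch(poolMode?: PoolMode, padding?: PaddingMode): string {
  // Defaults are now string literals instead of enum members.
  if (padding == null) {
    padding = 'valid';
  }
  if (poolMode == null) {
    poolMode = 'max';
  }
  // Comparisons happen directly on the strings.
  const paddingString = padding === 'same' ? 'same' : 'valid';
  return poolMode === 'max' ? 'tfc.maxPool(..., ' + paddingString + ')'
                            : 'tfc.avgPool(..., ' + paddingString + ')';
}

console.log(pool2dDispatch());               // tfc.maxPool(..., valid)
console.log(pool2dDispatch('avg', 'same'));  // tfc.avgPool(..., same)
```
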
@@ -49,2 +49,3 @@ import { Scalar, Tensor } from '@tensorflow/tfjs-core';
export declare function resolveScalarsInLogs(logs: UnresolvedLogs): Promise<void>;
+ export declare function disposeTensorsInLogs(logs: UnresolvedLogs): void;
export declare class History extends Callback {
@@ -51,0 +52,0 @@ epoch: number[];

@@ -396,2 +396,14 @@ "use strict";
exports.resolveScalarsInLogs = resolveScalarsInLogs;
+ function disposeTensorsInLogs(logs) {
+ if (logs == null) {
+ return;
+ }
+ for (var key in logs) {
+ var value = logs[key];
+ if (typeof value !== 'number') {
+ value.dispose();
+ }
+ }
+ }
+ exports.disposeTensorsInLogs = disposeTensorsInLogs;
var History = (function (_super) {
@@ -398,0 +410,0 @@ __extends(History, _super);

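The new disposeTensorsInLogs helper is what the training loop (in a later hunk) calls after onBatchEnd to release tensor-valued log entries. The same pattern, as a self-contained sketch with a minimal stand-in for tfjs tensors:

```ts
// Anything with dispose() stands in for a tfjs-core Tensor/Scalar here.
interface Disposable { dispose(): void; }
type UnresolvedLogs = {[key: string]: number | Disposable};

function disposeTensorsInLogs(logs?: UnresolvedLogs): void {
  if (logs == null) {
    return;
  }
  for (const key in logs) {
    const value = logs[key];
    // Log entries that are not plain numbers are assumed to be live
    // tensors holding (possibly GPU) memory, so they are freed here.
    if (typeof value !== 'number') {
      value.dispose();
    }
  }
}
```
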
@@ -1,16 +1,12 @@
- export declare enum DataFormat {
- CHANNEL_FIRST = 0,
- CHANNEL_LAST = 1,
- }
- export declare enum PaddingMode {
- VALID = 0,
- SAME = 1,
- CASUAL = 2,
- }
- export declare enum PoolMode {
- MAX = 0,
- AVG = 1,
- }
+ export declare type DataFormat = 'channelFirst' | 'channelLast';
+ export declare const VALID_DATA_FORMAT_VALUES: string[];
+ export declare function checkDataFormat(value?: string): void;
+ export declare type PaddingMode = 'valid' | 'same' | 'casual';
+ export declare const VALID_PADDING_MODE_VALUES: string[];
+ export declare function checkPaddingMode(value?: string): void;
+ export declare type PoolMode = 'max' | 'avg';
+ export declare const VALID_POOL_MODE_VALUES: string[];
+ export declare function checkPoolMode(value?: string): void;
export declare function nameScope<T>(name: string, fn: () => T): T;
export declare function getUniqueTensorName(prefix: string): string;
export declare function isValidTensorName(name: string): boolean;

"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var errors_1 = require("./errors"); | ||
var generic_utils_1 = require("./utils/generic_utils"); | ||
var nameMap = new Map(); | ||
var DataFormat; | ||
(function (DataFormat) { | ||
DataFormat[DataFormat["CHANNEL_FIRST"] = 0] = "CHANNEL_FIRST"; | ||
DataFormat[DataFormat["CHANNEL_LAST"] = 1] = "CHANNEL_LAST"; | ||
})(DataFormat = exports.DataFormat || (exports.DataFormat = {})); | ||
generic_utils_1.SerializableEnumRegistry.register('data_format', { | ||
'channels_first': DataFormat.CHANNEL_FIRST, | ||
'channels_last': DataFormat.CHANNEL_LAST | ||
}); | ||
var PaddingMode; | ||
(function (PaddingMode) { | ||
PaddingMode[PaddingMode["VALID"] = 0] = "VALID"; | ||
PaddingMode[PaddingMode["SAME"] = 1] = "SAME"; | ||
PaddingMode[PaddingMode["CASUAL"] = 2] = "CASUAL"; | ||
})(PaddingMode = exports.PaddingMode || (exports.PaddingMode = {})); | ||
generic_utils_1.SerializableEnumRegistry.register('padding', { 'valid': PaddingMode.VALID, 'same': PaddingMode.SAME }); | ||
var PoolMode; | ||
(function (PoolMode) { | ||
PoolMode[PoolMode["MAX"] = 0] = "MAX"; | ||
PoolMode[PoolMode["AVG"] = 1] = "AVG"; | ||
})(PoolMode = exports.PoolMode || (exports.PoolMode = {})); | ||
generic_utils_1.SerializableEnumRegistry.register('data_format', { 'channels_first': 'channelFirst', 'channels_last': 'channelLast' }); | ||
exports.VALID_DATA_FORMAT_VALUES = ['channelFirst', 'channelLast', undefined, null]; | ||
function checkDataFormat(value) { | ||
if (value == null) { | ||
return; | ||
} | ||
if (exports.VALID_DATA_FORMAT_VALUES.indexOf(value) < 0) { | ||
throw new errors_1.ValueError(value + " is not a valid DataFormat. Valid values as " + exports.VALID_DATA_FORMAT_VALUES); | ||
} | ||
} | ||
exports.checkDataFormat = checkDataFormat; | ||
generic_utils_1.SerializableEnumRegistry.register('padding', { 'valid': 'valid', 'same': 'same', 'casual': 'casual' }); | ||
exports.VALID_PADDING_MODE_VALUES = ['valid', 'same', 'casual', undefined, null]; | ||
function checkPaddingMode(value) { | ||
if (value == null) { | ||
return; | ||
} | ||
if (exports.VALID_PADDING_MODE_VALUES.indexOf(value) < 0) { | ||
throw new errors_1.ValueError(value + " is not a valid PaddingMode. Valid values as " + exports.VALID_PADDING_MODE_VALUES); | ||
} | ||
} | ||
exports.checkPaddingMode = checkPaddingMode; | ||
exports.VALID_POOL_MODE_VALUES = ['max', 'avg', undefined, null]; | ||
function checkPoolMode(value) { | ||
if (value == null) { | ||
return; | ||
} | ||
if (exports.VALID_POOL_MODE_VALUES.indexOf(value) < 0) { | ||
throw new errors_1.ValueError(value + " is not a valid PoolMode. Valid values as " + exports.VALID_POOL_MODE_VALUES); | ||
} | ||
} | ||
exports.checkPoolMode = checkPoolMode; | ||
var _nameScopeStack = []; | ||
@@ -27,0 +39,0 @@ var _nameScopeDivider = '/'; |
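This is the heart of the release: DataFormat, PaddingMode, and PoolMode change from numeric enums to string-literal union types, with runtime check* helpers guarding plain-JavaScript callers. A compact, self-contained sketch of the pattern (plain Error stands in for the library's ValueError):

```ts
type DataFormat = 'channelFirst' | 'channelLast';
const VALID_DATA_FORMAT_VALUES = ['channelFirst', 'channelLast', undefined, null];

function checkDataFormat(value?: string): void {
  if (value == null) {
    return;  // undefined/null pass through; a default is applied later
  }
  if (VALID_DATA_FORMAT_VALUES.indexOf(value) < 0) {
    throw new Error(
        value + ' is not a valid DataFormat: ' + VALID_DATA_FORMAT_VALUES);
  }
}

checkDataFormat('channelLast');  // ok
checkDataFormat(undefined);      // ok: caller fills in the default
// checkDataFormat('channels_last');  // throws: the snake_case names are
//                                    // mapped via SerializableEnumRegistry
```
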
@@ -81,6 +81,6 @@ import { Optimizer, Scalar, Tensor, Tensor1D } from '@tensorflow/tfjs-core';
private checkTrainableWeightsConsistency();
- evaluate(x: Tensor | Tensor[], y: Tensor | Tensor[], config?: ModelEvaluateConfig): Promise<Scalar | Scalar[]>;
+ evaluate(x: Tensor | Tensor[], y: Tensor | Tensor[], config?: ModelEvaluateConfig): Scalar | Scalar[];
private checkNumSamples(ins, batchSize?, steps?, stepsName?);
private predictLoop(ins, batchSize?, verbose?);
- predict(x: Tensor | Tensor[], config?: ModelPredictConfig): Promise<Tensor | Tensor[]>;
+ predict(x: Tensor | Tensor[], config?: ModelPredictConfig): Tensor | Tensor[];
predictOnBatch(x: Tensor): Tensor | Tensor[];
@@ -87,0 +87,0 @@ protected standardizeUserData(x: Tensor | Tensor[] | {

@@ -519,14 +519,9 @@ "use strict";
if (config === void 0) { config = {}; }
- return __awaiter(this, void 0, void 0, function () {
- var batchSize, standardizedOuts, ins, f, testOuts;
- return __generator(this, function (_a) {
- batchSize = config.batchSize == null ? 32 : config.batchSize;
- standardizedOuts = this.standardizeUserData(x, y, true, batchSize);
- ins = standardizedOuts[0].concat(standardizedOuts[1]);
- this.makeTestFunction();
- f = this.testFunction;
- testOuts = this.testLoop(f, ins, batchSize, config.verbose, config.steps);
- return [2, generic_utils_1.singletonOrArray(testOuts)];
- });
- });
+ var batchSize = config.batchSize == null ? 32 : config.batchSize;
+ var standardizedOuts = this.standardizeUserData(x, y, true, batchSize);
+ var ins = standardizedOuts[0].concat(standardizedOuts[1]);
+ this.makeTestFunction();
+ var f = this.testFunction;
+ var testOuts = this.testLoop(f, ins, batchSize, config.verbose, config.steps);
+ return generic_utils_1.singletonOrArray(testOuts);
};
@@ -603,10 +598,5 @@ Model.prototype.checkNumSamples = function (ins, batchSize, steps, stepsName) {
if (config === void 0) { config = {}; }
- return __awaiter(this, void 0, void 0, function () {
- var batchSize;
- return __generator(this, function (_a) {
- checkInputData(x, this.inputNames, this.feedInputShapes, false);
- batchSize = config.batchSize == null ? 32 : config.batchSize;
- return [2, this.predictLoop(x, batchSize)];
- });
- });
+ checkInputData(x, this.inputNames, this.feedInputShapes, false);
+ var batchSize = config.batchSize == null ? 32 : config.batchSize;
+ return this.predictLoop(x, batchSize);
};
@@ -759,3 +749,2 @@ Model.prototype.predictOnBatch = function (x) {
}
return outs;
- });
@@ -765,2 +754,3 @@ return [4, callbackList.onBatchEnd(batchIndex, batchLogs)];
_a.sent();
+ callbacks_1.disposeTensorsInLogs(batchLogs);
return [2];
@@ -781,4 +771,7 @@ }
return [3, 3];
- case 6: return [4, callbackList.onEpochEnd(epoch, epochLogs)];
- case 7:
+ case 6:
+ epochIndexArray1D_1.dispose();
+ _a.label = 7;
+ case 7: return [4, callbackList.onEpochEnd(epoch, epochLogs)];
+ case 8:
_a.sent();
@@ -961,3 +954,2 @@ return [2];
var meanLoss = K.mean(loss);
- K.keep(meanLoss);
lossValues.push(meanLoss);
@@ -981,8 +973,3 @@ if (i === 0) {
};
- _this.optimizer.updateVariables(totalLossFunction, _this.collectedTrainableWeights);
- var totalLossValue = lossValues[0];
- for (var i = 1; i < lossValues.length; ++i) {
- totalLossValue =
- K.scalarPlusArray(totalLossValue, lossValues[i]);
- }
+ var totalLossValue = _this.optimizer.updateVariables(totalLossFunction, _this.collectedTrainableWeights);
return [totalLossValue].concat(metricsValues);
@@ -989,0 +976,0 @@ };

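With the __awaiter/__generator wrappers gone, Model.evaluate() and Model.predict() return their results synchronously; only tensor reads (data()) remain async. A usage sketch against the 0.0.5 typings above, assuming the era's union @tensorflow/tfjs package re-exports this Model API:

```ts
import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]);

// 0.0.4: const loss = await model.evaluate(xs, ys);
// 0.0.5: the Scalar comes back directly; no await.
const loss = model.evaluate(xs, ys) as tf.Scalar;
loss.print();
```
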
@@ -9,3 +9,3 @@ import { MaxNorm, MaxNormConfig, MinMaxNorm, MinMaxNormConfig, NonNeg, UnitNorm, UnitNormConfig } from './constraints';
import { EmbeddingLayerConfig } from './layers/embeddings';
- import { ConcatenateLayerConfig, MergeLayerConfig } from './layers/merge';
+ import { ConcatenateLayerConfig } from './layers/merge';
import { BatchNormalizationLayerConfig } from './layers/normalization';
@@ -26,2 +26,4 @@ import { GlobalPooling2DLayerConfig, Pooling1DLayerConfig, Pooling2DLayerConfig } from './layers/pooling';
export declare class LayerExports {
+ static Layer: typeof Layer;
+ static RNNCell: typeof RNNCell;
static conv1d(config: ConvLayerConfig): Layer;
@@ -36,8 +38,8 @@ static conv2d(config: ConvLayerConfig): Layer;
static embedding(config: EmbeddingLayerConfig): Layer;
- static add(config: MergeLayerConfig): Layer;
- static average(config: MergeLayerConfig): Layer;
+ static add(config: LayerConfig): Layer;
+ static average(config: LayerConfig): Layer;
static concatenate(config: ConcatenateLayerConfig): Layer;
- static maximum(config: MergeLayerConfig): Layer;
- static minimum(config: MergeLayerConfig): Layer;
- static multiply(config: MergeLayerConfig): Layer;
+ static maximum(config: LayerConfig): Layer;
+ static minimum(config: LayerConfig): Layer;
+ static multiply(config: LayerConfig): Layer;
static batchNormalization(config: BatchNormalizationLayerConfig): Layer;
@@ -44,0 +46,0 @@ static avgPooling1d(config: Pooling1DLayerConfig): Layer;

@@ -190,2 +190,4 @@ "use strict";
};
+ LayerExports.Layer = topology_1.Layer;
+ LayerExports.RNNCell = recurrent_1.RNNCell;
__decorate([
@@ -192,0 +194,0 @@ tfjs_core_1.doc({

@@ -5,5 +5,4 @@ import * as dl from '@tensorflow/tfjs-core';
export { Callback, CallbackList, CustomCallback, CustomCallbackConfig, Logs } from './callbacks';
- export { Layer } from './engine/topology';
export { Model, ModelCompileConfig, ModelEvaluateConfig, ModelFitConfig, ModelPredictConfig } from './engine/training';
- export { GRUCellLayerConfig, GRULayerConfig, LSTMCellLayerConfig, LSTMLayerConfig, RNN, RNNCell, RNNLayerConfig, SimpleRNNCellLayerConfig, SimpleRNNLayerConfig } from './layers/recurrent';
+ export { GRUCellLayerConfig, GRULayerConfig, LSTMCellLayerConfig, LSTMLayerConfig, RNN, RNNLayerConfig, SimpleRNNCellLayerConfig, SimpleRNNLayerConfig } from './layers/recurrent';
export { ModelAndWeightsConfig, Sequential, SequentialConfig } from './models';
@@ -10,0 +9,0 @@ export { SymbolicTensor } from './types';

@@ -12,4 +12,2 @@ "use strict";
exports.CustomCallback = callbacks_1.CustomCallback;
- var topology_1 = require("./engine/topology");
- exports.Layer = topology_1.Layer;
var training_1 = require("./engine/training");
@@ -19,3 +17,2 @@ exports.Model = training_1.Model;
exports.RNN = recurrent_1.RNN;
- exports.RNNCell = recurrent_1.RNNCell;
var models_1 = require("./models");
@@ -22,0 +19,0 @@ exports.Sequential = models_1.Sequential;

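Layer and RNNCell stop being top-level package exports and move onto LayerExports, the class behind the tf.layers namespace. Assuming tf.layers maps to LayerExports (as the hunk above suggests), the migration for direct importers looks like:

```ts
import * as tf from '@tensorflow/tfjs';

// 0.0.4: import {Layer} from '@tensorflow/tfjs-layers';
//        import {RNNCell} from '@tensorflow/tfjs-layers';
// 0.0.5: both classes hang off the layers namespace instead.
const LayerClass = tf.layers.Layer;
const RNNCellClass = tf.layers.RNNCell;
console.log(typeof LayerClass, typeof RNNCellClass);  // 'function' 'function'
```
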
@@ -5,11 +5,8 @@ import { Tensor } from '@tensorflow/tfjs-core';
import { Constructor } from './utils/generic_utils';
- export declare enum FanMode {
- FAN_IN = 0,
- FAN_OUT = 1,
- FAN_AVG = 2,
- }
- export declare enum Distribution {
- NORMAL = 0,
- UNIFORM = 1,
- }
+ export declare type FanMode = 'fanIn' | 'fanOut' | 'fanAvg';
+ export declare const VALID_FAN_MODE_VALUES: string[];
+ export declare function checkFanMode(value?: string): void;
+ export declare type Distribution = 'normal' | 'uniform';
+ export declare const VALID_DISTRIBUTION_VALUES: string[];
+ export declare function checkDistribution(value?: string): void;
export declare abstract class Initializer {
@@ -16,0 +13,0 @@ static fromConfig<T>(cls: Constructor<T>, config: ConfigDict): T;

@@ -26,19 +26,24 @@ "use strict";
var math_utils_1 = require("./utils/math_utils");
- var FanMode;
- (function (FanMode) {
- FanMode[FanMode["FAN_IN"] = 0] = "FAN_IN";
- FanMode[FanMode["FAN_OUT"] = 1] = "FAN_OUT";
- FanMode[FanMode["FAN_AVG"] = 2] = "FAN_AVG";
- })(FanMode = exports.FanMode || (exports.FanMode = {}));
- generic_utils_1.SerializableEnumRegistry.register('mode', {
- 'fan_in': FanMode.FAN_IN,
- 'fan_out': FanMode.FAN_OUT,
- 'fan_avg': FanMode.FAN_AVG
- });
- var Distribution;
- (function (Distribution) {
- Distribution[Distribution["NORMAL"] = 0] = "NORMAL";
- Distribution[Distribution["UNIFORM"] = 1] = "UNIFORM";
- })(Distribution = exports.Distribution || (exports.Distribution = {}));
- generic_utils_1.SerializableEnumRegistry.register('distribution', { 'normal': Distribution.NORMAL, 'uniform': Distribution.UNIFORM });
+ generic_utils_1.SerializableEnumRegistry.register('mode', { 'fan_in': 'fanIn', 'fan_out': 'fanOut', 'fan_avg': 'fanAvg' });
+ exports.VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg', undefined, null];
+ function checkFanMode(value) {
+ if (value == null) {
+ return;
+ }
+ if (exports.VALID_FAN_MODE_VALUES.indexOf(value) < 0) {
+ throw new errors_1.ValueError(value + " is not a valid FanMode. Valid values as " + exports.VALID_FAN_MODE_VALUES);
+ }
+ }
+ exports.checkFanMode = checkFanMode;
+ generic_utils_1.SerializableEnumRegistry.register('distribution', { 'normal': 'normal', 'uniform': 'uniform' });
+ exports.VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', undefined, null];
+ function checkDistribution(value) {
+ if (value == null) {
+ return;
+ }
+ if (exports.VALID_DISTRIBUTION_VALUES.indexOf(value) < 0) {
+ throw new errors_1.ValueError(value + " is not a valid Distribution. Valid values as " + exports.VALID_DISTRIBUTION_VALUES);
+ }
+ }
+ exports.checkDistribution = checkDistribution;
var Initializer = (function () {
@@ -192,5 +197,6 @@ function Initializer() {
function computeFans(shape, dataFormat) {
- if (dataFormat === void 0) { dataFormat = common_1.DataFormat.CHANNEL_LAST; }
+ if (dataFormat === void 0) { dataFormat = 'channelLast'; }
var fanIn;
var fanOut;
+ common_1.checkDataFormat(dataFormat);
if (shape.length === 2) {
@@ -201,3 +207,3 @@ fanIn = shape[0];
else if (_.contains([3, 4, 5], shape.length)) {
- if (dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ if (dataFormat === 'channelFirst') {
var receptiveFieldSize = math_utils_1.arrayProd(shape, 2);
@@ -207,3 +213,3 @@ fanIn = shape[1] * receptiveFieldSize;
}
- else if (dataFormat === common_1.DataFormat.CHANNEL_LAST) {
+ else if (dataFormat === 'channelLast') {
var receptiveFieldSize = math_utils_1.arrayProd(shape, 0, shape.length - 2);
@@ -213,5 +219,2 @@ fanIn = shape[shape.length - 2] * receptiveFieldSize;
}
- else {
- throw new errors_1.ValueError("Invalid dataFormat: " + dataFormat);
- }
}
@@ -234,3 +237,5 @@ else {
_this.mode = config.mode;
+ checkFanMode(_this.mode);
_this.distribution = config.distribution;
+ checkDistribution(_this.distribution);
_this.seed = config.seed;
@@ -244,6 +249,6 @@ return _this;
var scale = this.scale;
- if (this.mode === FanMode.FAN_IN) {
+ if (this.mode === 'fanIn') {
scale /= Math.max(1, fanIn);
}
- else if (this.mode === FanMode.FAN_OUT) {
+ else if (this.mode === 'fanOut') {
scale /= Math.max(1, fanOut);
@@ -254,3 +259,3 @@ }
}
- if (this.distribution === Distribution.NORMAL) {
+ if (this.distribution === 'normal') {
var stddev = Math.sqrt(scale);
@@ -281,4 +286,4 @@ return K.truncatedNormal(shape, 0, stddev, dtype, this.seed);
scale: 1.0,
- mode: FanMode.FAN_AVG,
- distribution: Distribution.UNIFORM,
+ mode: 'fanAvg',
+ distribution: 'uniform',
seed: config.seed
@@ -296,4 +301,4 @@ }) || this;
scale: 1.0,
- mode: FanMode.FAN_AVG,
- distribution: Distribution.NORMAL,
+ mode: 'fanAvg',
+ distribution: 'normal',
seed: config.seed
@@ -309,8 +314,3 @@ }) || this;
function HeNormal(config) {
- return _super.call(this, {
- scale: 2.0,
- mode: FanMode.FAN_IN,
- distribution: Distribution.NORMAL,
- seed: config.seed
- }) || this;
+ return _super.call(this, { scale: 2.0, mode: 'fanIn', distribution: 'normal', seed: config.seed }) || this;
}
@@ -324,8 +324,3 @@ return HeNormal;
function LeCunNormal(config) {
- return _super.call(this, {
- scale: 1.0,
- mode: FanMode.FAN_IN,
- distribution: Distribution.NORMAL,
- seed: config.seed
- }) || this;
+ return _super.call(this, { scale: 1.0, mode: 'fanIn', distribution: 'normal', seed: config.seed }) || this;
}
@@ -332,0 +327,0 @@ return LeCunNormal;

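computeFans keeps its arithmetic but now branches on string data formats, and the old invalid-format else branch is subsumed by checkDataFormat. A self-contained sketch of the computation follows; the fanOut expressions for the conv branches are inferred symmetrically from the fanIn lines visible above, so treat them as an assumption:

```ts
type DataFormat = 'channelFirst' | 'channelLast';

function arrayProd(shape: number[], begin = 0, end = shape.length): number {
  let prod = 1;
  for (let i = begin; i < end; ++i) {
    prod *= shape[i];
  }
  return prod;
}

function computeFans(
    shape: number[], dataFormat: DataFormat = 'channelLast'): [number, number] {
  if (shape.length === 2) {
    return [shape[0], shape[1]];  // dense kernel: [fanIn, fanOut]
  }
  if ([3, 4, 5].indexOf(shape.length) >= 0) {  // conv kernels
    if (dataFormat === 'channelFirst') {
      const receptiveFieldSize = arrayProd(shape, 2);
      return [shape[1] * receptiveFieldSize, shape[0] * receptiveFieldSize];
    }
    // 'channelLast'
    const receptiveFieldSize = arrayProd(shape, 0, shape.length - 2);
    return [shape[shape.length - 2] * receptiveFieldSize,
            shape[shape.length - 1] * receptiveFieldSize];
  }
  // Other ranks: split the total size evenly (Keras convention).
  const size = Math.sqrt(arrayProd(shape));
  return [size, size];
}

console.log(computeFans([3, 3, 16, 32]));  // [144, 288] for a 3x3 conv kernel
```
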
@@ -14,3 +14,2 @@ "use strict";
var K = require("../backend/deeplearnjs_backend");
- var common_1 = require("../common");
var constraints_1 = require("../constraints");
@@ -42,3 +41,3 @@ var errors_1 = require("../errors");
}
- var channelAxis = this.dataFormat === common_1.DataFormat.CHANNEL_FIRST ? 1 : 3;
+ var channelAxis = this.dataFormat === 'channelFirst' ? 1 : 3;
if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
@@ -74,7 +73,5 @@ throw new errors_1.ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' +
inputShape = generic_utils_1.getExactlyOneShape(inputShape);
- var rows = this.dataFormat === common_1.DataFormat.CHANNEL_FIRST ? inputShape[2] :
- inputShape[1];
- var cols = this.dataFormat === common_1.DataFormat.CHANNEL_FIRST ? inputShape[3] :
- inputShape[2];
- var outFilters = this.dataFormat === common_1.DataFormat.CHANNEL_FIRST ?
+ var rows = this.dataFormat === 'channelFirst' ? inputShape[2] : inputShape[1];
+ var cols = this.dataFormat === 'channelFirst' ? inputShape[3] : inputShape[2];
+ var outFilters = this.dataFormat === 'channelFirst' ?
inputShape[1] * this.depthMultiplier :
@@ -84,3 +81,3 @@ inputShape[3] * this.depthMultiplier;
var outCols = conv_utils_1.convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
- if (this.dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ if (this.dataFormat === 'channelFirst') {
return [inputShape[0], outFilters, outRows, outCols];
@@ -87,0 +84,0 @@ }

@@ -40,5 +40,7 @@ "use strict";
_this.strides = conv_utils_1.normalizeArray(config.strides == null ? 1 : config.strides, rank, 'strides');
- _this.padding = config.padding == null ? common_1.PaddingMode.VALID : config.padding;
+ _this.padding = config.padding == null ? 'valid' : config.padding;
+ common_1.checkPaddingMode(_this.padding);
_this.dataFormat =
- config.dataFormat == null ? common_1.DataFormat.CHANNEL_LAST : config.dataFormat;
+ config.dataFormat == null ? 'channelLast' : config.dataFormat;
+ common_1.checkDataFormat(_this.dataFormat);
_this.dilationRate = config.dilationRate == null ? 1 : config.dilationRate;
@@ -65,5 +67,3 @@ if (!(_this.dilationRate === 1 ||
inputShape = generic_utils.getExactlyOneShape(inputShape);
- var channelAxis = this.dataFormat === common_1.DataFormat.CHANNEL_FIRST ?
- 1 :
- inputShape.length - 1;
+ var channelAxis = this.dataFormat === 'channelFirst' ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
@@ -104,3 +104,3 @@ throw new errors_1.ValueError("The channel dimension of the input should be defined. " +
var newSpace = [];
- var space = (this.dataFormat === common_1.DataFormat.CHANNEL_LAST) ?
+ var space = (this.dataFormat === 'channelLast') ?
inputShape.slice(1, inputShape.length - 1) :
@@ -114,3 +114,3 @@ inputShape.slice(2);
var outputShape = [inputShape[0]];
- if (this.dataFormat === common_1.DataFormat.CHANNEL_LAST) {
+ if (this.dataFormat === 'channelLast') {
outputShape = outputShape.concat(newSpace);
@@ -117,0 +117,0 @@ outputShape.push(this.filters);

@@ -63,3 +63,3 @@ import { Tensor } from '@tensorflow/tfjs-core';
export interface ActivationLayerConfig extends LayerConfig {
- activation: string;
+ activation: ActivationIdentifier;
}
@@ -66,0 +66,0 @@ export declare class Activation extends Layer {

import { Tensor } from '@tensorflow/tfjs-core';
import { Layer, LayerConfig } from '../engine/topology';
import { Shape } from '../types';
- export interface MergeLayerConfig extends LayerConfig {
- }
export declare class Merge extends Layer {
protected reshapeRequired: boolean;
- constructor(config: MergeLayerConfig);
+ constructor(config?: LayerConfig);
protected mergeFunction(inputs: Tensor[]): Tensor;
@@ -16,17 +14,22 @@ private computeElementwiseOpOutputShape(shape1, shape2);
export declare class Add extends Merge {
+ constructor(config?: LayerConfig);
protected mergeFunction(inputs: Tensor[]): Tensor;
}
export declare class Multiply extends Merge {
+ constructor(config?: LayerConfig);
protected mergeFunction(inputs: Tensor[]): Tensor;
}
export declare class Average extends Merge {
+ constructor(config?: LayerConfig);
protected mergeFunction(inputs: Tensor[]): Tensor;
}
export declare class Maximum extends Merge {
+ constructor(config?: LayerConfig);
protected mergeFunction(inputs: Tensor[]): Tensor;
}
export declare class Minimum extends Merge {
+ constructor(config?: LayerConfig);
protected mergeFunction(inputs: Tensor[]): Tensor;
}
- export interface ConcatenateLayerConfig extends MergeLayerConfig {
+ export interface ConcatenateLayerConfig extends LayerConfig {
axis?: number;
@@ -33,0 +36,0 @@ }

@@ -22,3 +22,3 @@ "use strict";
function Merge(config) {
- var _this = _super.call(this, config) || this;
+ var _this = _super.call(this, config || {}) || this;
_this.supportsMasking = true;
@@ -194,4 +194,4 @@ return _this;
__extends(Add, _super);
- function Add() {
- return _super !== null && _super.apply(this, arguments) || this;
+ function Add(config) {
+ return _super.call(this, config) || this;
}
@@ -212,4 +212,4 @@ Add.prototype.mergeFunction = function (inputs) {
__extends(Multiply, _super);
- function Multiply() {
- return _super !== null && _super.apply(this, arguments) || this;
+ function Multiply(config) {
+ return _super.call(this, config) || this;
}
@@ -230,4 +230,4 @@ Multiply.prototype.mergeFunction = function (inputs) {
__extends(Average, _super);
- function Average() {
- return _super !== null && _super.apply(this, arguments) || this;
+ function Average(config) {
+ return _super.call(this, config) || this;
}
@@ -248,4 +248,4 @@ Average.prototype.mergeFunction = function (inputs) {
__extends(Maximum, _super);
- function Maximum() {
- return _super !== null && _super.apply(this, arguments) || this;
+ function Maximum(config) {
+ return _super.call(this, config) || this;
}
@@ -266,4 +266,4 @@ Maximum.prototype.mergeFunction = function (inputs) {
__extends(Minimum, _super);
- function Minimum() {
- return _super !== null && _super.apply(this, arguments) || this;
+ function Minimum(config) {
+ return _super.call(this, config) || this;
}
@@ -270,0 +270,0 @@ Minimum.prototype.mergeFunction = function (inputs) {

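Each Merge subclass gains an explicit constructor so the config argument can be dropped entirely, with the base class defaulting it via config || {}. The shape of that change as a standalone sketch:

```ts
interface LayerConfig { name?: string; }

class Merge {
  supportsMasking = true;
  constructor(config?: LayerConfig) {
    // Mirrors `_super.call(this, config || {})` in the hunk above.
    const resolved = config || {};
    console.log('layer name:', resolved.name == null ? '(auto)' : resolved.name);
  }
}

class Add extends Merge {
  // 0.0.4 compiled to `_super.apply(this, arguments)` with a required config;
  // 0.0.5 forwards an explicitly optional one.
  constructor(config?: LayerConfig) {
    super(config);
  }
}

const add = new Add();  // now legal: no config object required
```
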
@@ -30,3 +30,4 @@ "use strict";
_this.strides = config.strides == null ? _this.poolSize : [config.strides];
- _this.padding = config.padding == null ? common_1.PaddingMode.VALID : config.padding;
+ _this.padding = config.padding == null ? 'valid' : config.padding;
+ common_1.checkPaddingMode(_this.padding);
_this.inputSpec = [new topology_1.InputSpec({ ndim: 3 })];
@@ -43,3 +44,3 @@ return _this;
inputs = K.expandDims(generic_utils.getExactlyOneTensor(inputs), 2);
- var output = this.poolingFunction(generic_utils.getExactlyOneTensor(inputs), [this.poolSize[0], 1], [this.strides[0], 1], this.padding, common_1.DataFormat.CHANNEL_LAST);
+ var output = this.poolingFunction(generic_utils.getExactlyOneTensor(inputs), [this.poolSize[0], 1], [this.strides[0], 1], this.padding, 'channelLast');
return K.squeeze(output, 2);
@@ -66,3 +67,5 @@ };
MaxPooling1D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) {
- return K.pool2d(inputs, poolSize, strides, padding, dataFormat, common_1.PoolMode.MAX);
+ common_1.checkDataFormat(dataFormat);
+ common_1.checkPaddingMode(padding);
+ return K.pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
};
@@ -79,3 +82,5 @@ return MaxPooling1D;
AvgPooling1D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) {
- return K.pool2d(inputs, poolSize, strides, padding, dataFormat, common_1.PoolMode.AVG);
+ common_1.checkDataFormat(dataFormat);
+ common_1.checkPaddingMode(padding);
+ return K.pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
};
@@ -98,5 +103,7 @@ return AvgPooling1D;
_this.strides = config.strides == null ? _this.poolSize : config.strides;
- _this.padding = config.padding == null ? common_1.PaddingMode.VALID : config.padding;
+ _this.padding = config.padding == null ? 'valid' : config.padding;
_this.dataFormat =
- config.dataFormat == null ? common_1.DataFormat.CHANNEL_LAST : config.dataFormat;
+ config.dataFormat == null ? 'channelLast' : config.dataFormat;
+ common_1.checkDataFormat(_this.dataFormat);
+ common_1.checkPaddingMode(_this.padding);
_this.inputSpec = [new topology_1.InputSpec({ ndim: 4 })];
@@ -107,6 +114,4 @@ return _this;
inputShape = generic_utils.getExactlyOneShape(inputShape);
- var rows = this.dataFormat === common_1.DataFormat.CHANNEL_FIRST ? inputShape[2] :
- inputShape[1];
- var cols = this.dataFormat === common_1.DataFormat.CHANNEL_FIRST ? inputShape[3] :
- inputShape[2];
+ var rows = this.dataFormat === 'channelFirst' ? inputShape[2] : inputShape[1];
+ var cols = this.dataFormat === 'channelFirst' ? inputShape[3] : inputShape[2];
rows =
@@ -116,3 +121,3 @@ conv_utils_1.convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]);
conv_utils_1.convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]);
- if (this.dataFormat === common_1.DataFormat.CHANNEL_FIRST) {
+ if (this.dataFormat === 'channelFirst') {
return [inputShape[0], inputShape[1], rows, cols];
@@ -148,3 +153,5 @@ }
MaxPooling2D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) {
- return K.pool2d(inputs, poolSize, strides, padding, dataFormat, common_1.PoolMode.MAX);
+ common_1.checkDataFormat(dataFormat);
+ common_1.checkPaddingMode(padding);
+ return K.pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
};
@@ -161,3 +168,5 @@ return MaxPooling2D;
AvgPooling2D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) {
- return K.pool2d(inputs, poolSize, strides, padding, dataFormat, common_1.PoolMode.AVG);
+ common_1.checkDataFormat(dataFormat);
+ common_1.checkPaddingMode(padding);
+ return K.pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
};
@@ -215,3 +224,4 @@ return AvgPooling2D;
_this.dataFormat =
- config.dataFormat == null ? common_1.DataFormat.CHANNEL_LAST : config.dataFormat;
+ config.dataFormat == null ? 'channelLast' : config.dataFormat;
+ common_1.checkDataFormat(_this.dataFormat);
_this.inputSpec = [new topology_1.InputSpec({ ndim: 4 })];
@@ -222,3 +232,3 @@ return _this;
inputShape = inputShape;
- if (this.dataFormat === common_1.DataFormat.CHANNEL_LAST) {
+ if (this.dataFormat === 'channelLast') {
return [inputShape[0], inputShape[3]];
@@ -249,3 +259,3 @@ }
var input = generic_utils.getExactlyOneTensor(inputs);
- if (this.dataFormat === common_1.DataFormat.CHANNEL_LAST) {
+ if (this.dataFormat === 'channelLast') {
return K.mean(input, [1, 2]);
@@ -268,3 +278,3 @@ }
var input = generic_utils.getExactlyOneTensor(inputs);
- if (this.dataFormat === common_1.DataFormat.CHANNEL_LAST) {
+ if (this.dataFormat === 'channelLast') {
return K.max(input, [1, 2]);
@@ -271,0 +281,0 @@ }

@@ -193,3 +193,3 @@ import { Tensor } from '@tensorflow/tfjs-core';
readonly recurrentDropout: number;
- readonly implementatin: number;
+ readonly implementation: number;
getConfig(): ConfigDict;
@@ -259,3 +259,3 @@ static fromConfig<T>(cls: generic_utils.Constructor<T>, config: ConfigDict): T;
readonly recurrentDropout: number;
- readonly implementatin: number;
+ readonly implementation: number;
getConfig(): ConfigDict;
@@ -262,0 +262,0 @@ static fromConfig<T>(cls: generic_utils.Constructor<T>, config: ConfigDict): T;

@@ -867,3 +867,3 @@ "use strict";
});
- Object.defineProperty(GRU.prototype, "implementatin", {
+ Object.defineProperty(GRU.prototype, "implementation", {
get: function () {
@@ -892,3 +892,3 @@ return this.cell.implementation;
recurrentDropout: this.recurrentDropout,
- implementation: this.implementatin,
+ implementation: this.implementation,
};
@@ -1201,3 +1201,3 @@ var baseConfig = _super.prototype.getConfig.call(this);
});
- Object.defineProperty(LSTM.prototype, "implementatin", {
+ Object.defineProperty(LSTM.prototype, "implementation", {
get: function () {
@@ -1227,3 +1227,3 @@ return this.cell.implementation;
recurrentDropout: this.recurrentDropout,
- implementation: this.implementatin,
+ implementation: this.implementation,
};
@@ -1230,0 +1230,0 @@ var baseConfig = _super.prototype.getConfig.call(this);

@@ -29,4 +29,4 @@ import { Scalar, Tensor, WeightsManifestConfig } from '@tensorflow/tfjs-core';
updatable: boolean;
- evaluate(x: Tensor | Tensor[], y: Tensor | Tensor[], config?: ModelEvaluateConfig): Promise<Scalar | Scalar[]>;
- predict(x: Tensor | Tensor[], config?: ModelPredictConfig): Promise<Tensor | Tensor[]>;
+ evaluate(x: Tensor | Tensor[], y: Tensor | Tensor[], config?: ModelEvaluateConfig): Scalar | Scalar[];
+ predict(x: Tensor | Tensor[], config?: ModelPredictConfig): Tensor | Tensor[];
predictOnBatch(x: Tensor): Tensor | Tensor[];
@@ -33,0 +33,0 @@ compile(config: ModelCompileConfig): void;

@@ -259,21 +259,13 @@ "use strict";
if (config === void 0) { config = {}; }
- return __awaiter(this, void 0, void 0, function () {
- return __generator(this, function (_a) {
- if (!this.built) {
- throw new errors_1.RuntimeError('The model needs to be compiled before being used.');
- }
- return [2, this.model.evaluate(x, y, config)];
- });
- });
+ if (!this.built) {
+ throw new errors_1.RuntimeError('The model needs to be compiled before being used.');
+ }
+ return this.model.evaluate(x, y, config);
};
Sequential.prototype.predict = function (x, config) {
if (config === void 0) { config = {}; }
- return __awaiter(this, void 0, void 0, function () {
- return __generator(this, function (_a) {
- if (this.model == null) {
- this.build();
- }
- return [2, this.model.predict(x, config)];
- });
- });
+ if (this.model == null) {
+ this.build();
+ }
+ return this.model.predict(x, config);
};
@@ -280,0 +272,0 @@ Sequential.prototype.predictOnBatch = function (x) {

@@ -17,3 +17,3 @@ import { AdagradOptimizer, AdamOptimizer, Optimizer as CoreOptimizer, RMSPropOptimizer, Scalar, SGDOptimizer } from '@tensorflow/tfjs-core';
getConfig(): ConfigDict;
- updateVariables(lossFn: () => Scalar, params: LayerVariable[]): void;
+ updateVariables(lossFn: () => Scalar, params: LayerVariable[]): Scalar;
static fromConfig<T>(cls: Constructor<T>, config: ConfigDict): T;
@@ -20,0 +20,0 @@ }

@@ -46,3 +46,3 @@ "use strict";
var variables = params.map(function (param) { return param.read(); });
- this.optimizer.minimize(lossFn, false, variables);
+ return this.optimizer.minimize(lossFn, true, variables);
};
@@ -49,0 +49,0 @@ LayersOptimizer.fromConfig = function (cls, config) {

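updateVariables now asks tfjs-core's Optimizer.minimize() to return the cost (second argument true) and forwards that Scalar, which is what lets the training loop above drop its manual K.scalarPlusArray summation. A small standalone example of the minimize(..., returnCost) contract this relies on:

```ts
import {Scalar, scalar, train, variable} from '@tensorflow/tfjs-core';

const w = variable(scalar(3));
const optimizer = train.sgd(0.1);

// With returnCost = true, minimize() hands back the loss it computed while
// taking the gradient step, instead of returning null.
const cost = optimizer.minimize(() => w.square() as Scalar, true) as Scalar;
cost.print();  // the loss value for this step
w.print();     // the variable was updated in place
```
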
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var common_1 = require("../common"); | ||
var errors_1 = require("../errors"); | ||
@@ -35,3 +34,3 @@ var generic_utils_1 = require("./generic_utils"); | ||
var outputLength; | ||
if (padding === common_1.PaddingMode.SAME) { | ||
if (padding === 'same') { | ||
outputLength = inputLength; | ||
@@ -38,0 +37,0 @@ } |
@@ -1,2 +1,2 @@
- declare const version = "0.0.4";
+ declare const version = "0.0.5";
export { version };

"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var version = '0.0.4'; | ||
var version = '0.0.5'; | ||
exports.version = version; |
{
"name": "@tensorflow/tfjs-layers",
- "version": "0.0.4",
+ "version": "0.0.5",
"description": "TensorFlow layers API in JavaScript",
@@ -5,0 +5,0 @@ "private": false,

@@ -56,3 +56,3 @@ # TensorFlow.js Layers: High-Level Machine Learning Model API
// Ater the training, perform inference.
- const output = await model.predict(tf.tensor2d([[5]], [1, 1]));
+ const output = model.predict(tf.tensor2d([[5]], [1, 1]));
output.print();
@@ -59,0 +59,0 @@ ```

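The README change follows directly from the synchronous predict(): only fit() keeps its Promise. Under the same union-package assumption as the earlier evaluate() sketch, the full README flow now reads:

```ts
import * as tf from '@tensorflow/tfjs';

async function run() {
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [1]}));
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

  const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
  const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]);

  // Training is still asynchronous.
  await model.fit(xs, ys, {epochs: 100});

  // After the training, perform inference: no await needed as of 0.0.5.
  const output = model.predict(tf.tensor2d([[5]], [1, 1])) as tf.Tensor;
  output.print();
}
run();
```
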
Sorry, the diffs of 3 files are too big to display.

License Policy Violation
License: This package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package

New author
Supply chain risk: A new npm collaborator published a version of the package for the first time. New collaborators are usually benign additions to a project, but do indicate a change to the security surface area of a package.
Found 1 instance in 1 package