Comparing version 0.3.15 to 0.3.16
@@ -1,2 +0,1 @@ | ||
import { NDArrayMath } from '../math/math'; | ||
import { NDArray } from '../math/ndarray'; | ||
@@ -11,6 +10,5 @@ export interface DataStats { | ||
protected dataShapes: number[][]; | ||
protected math: NDArrayMath; | ||
protected dataset: NDArray[][] | null; | ||
private normalizationInfo; | ||
constructor(dataShapes: number[][], math: NDArrayMath); | ||
constructor(dataShapes: number[][]); | ||
getDataShape(dataIndex: number): number[]; | ||
@@ -17,0 +15,0 @@ abstract fetchData(): Promise<void>; |
@@ -7,5 +7,4 @@ "use strict"; | ||
var InMemoryDataset = (function () { | ||
function InMemoryDataset(dataShapes, math) { | ||
function InMemoryDataset(dataShapes) { | ||
this.dataShapes = dataShapes; | ||
this.math = math; | ||
this.normalizationInfo = {}; | ||
@@ -48,3 +47,2 @@ } | ||
InMemoryDataset.prototype.normalizeExamplesToRange = function (examples, curLowerBounds, curUpperBounds, newLowerBounds, newUpperBounds) { | ||
var _this = this; | ||
var curBoundsIsPerDimension = (curUpperBounds instanceof Float32Array && | ||
@@ -82,3 +80,3 @@ curLowerBounds instanceof Float32Array); | ||
} | ||
newExamples.push(ndarray_1.NDArray.make(example.shape, { values: normalizedValues }, 'float32', _this.math)); | ||
newExamples.push(ndarray_1.NDArray.make(example.shape, { values: normalizedValues }, 'float32')); | ||
}); | ||
@@ -85,0 +83,0 @@ return newExamples; |
@@ -13,3 +13,2 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var math_1 = require("../math/math"); | ||
var ndarray_1 = require("../math/ndarray"); | ||
@@ -36,5 +35,3 @@ var util = require("../util"); | ||
function XhrDataset(xhrDatasetConfig) { | ||
var _this = this; | ||
var safeMode = false; | ||
_this = _super.call(this, xhrDatasetConfig.data.map(function (x) { return x.shape; }), new math_1.NDArrayMath('cpu', safeMode)) || this; | ||
var _this = _super.call(this, xhrDatasetConfig.data.map(function (x) { return x.shape; })) || this; | ||
_this.xhrDatasetConfig = xhrDatasetConfig; | ||
@@ -44,3 +41,2 @@ return _this; | ||
XhrDataset.prototype.getNDArray = function (info) { | ||
var _this = this; | ||
var dataPromise = info.dataType === 'png' ? | ||
@@ -54,3 +50,3 @@ parseTypedArrayFromPng(info, info.shape) : | ||
var values = data.subarray(i * inputSize, (i + 1) * inputSize); | ||
var ndarray = ndarray_1.NDArray.make(info.shape, { values: new Float32Array(values) }, 'float32', _this.math); | ||
var ndarray = ndarray_1.NDArray.make(info.shape, { values: new Float32Array(values) }, 'float32'); | ||
ndarrays.push(ndarray); | ||
@@ -57,0 +53,0 @@ } |
@@ -328,3 +328,3 @@ "use strict"; | ||
MathBackendCPU.prototype.multiply = function (a, b) { | ||
return this.broadcastedBinaryOp(a, b, a.dtype, function (aValue, bValue) { return aValue * bValue; }); | ||
return this.broadcastedBinaryOp(a, b, types.upcastType(a.dtype, b.dtype), function (aValue, bValue) { return aValue * bValue; }); | ||
}; | ||
@@ -331,0 +331,0 @@ MathBackendCPU.prototype.divide = function (a, b) { |
@@ -1,14 +0,28 @@ | ||
import { NDArray, Scalar } from '../ndarray'; | ||
import { DataTypes, NDArray, Scalar } from '../ndarray'; | ||
import { MathBackend } from './backend'; | ||
import { KernelConfigRegistry } from './kernel_registry'; | ||
import { ScopeResult, ScopeResultImmediate } from './tape_util'; | ||
export declare class BackendEngine { | ||
private backend; | ||
private masterTape; | ||
private safeMode; | ||
private nextTapeNodeId; | ||
private activeTape; | ||
private gradientScopeCount; | ||
private activeScope; | ||
private scopeStack; | ||
private debugMode; | ||
constructor(backend: MathBackend); | ||
constructor(backend: MathBackend, safeMode: boolean); | ||
enableDebugMode(): void; | ||
executeKernel<K extends keyof KernelConfigRegistry, C extends KernelConfigRegistry[K]['inputAndArgs']>(kernelName: K, config: C, grad?: KernelConfigRegistry[K]['gradient']): KernelConfigRegistry[K]['output']; | ||
gradientWrt(y: Scalar, xs: NDArray[]): NDArray[]; | ||
private checkForNaN(vals, dtype, name); | ||
gradients(f: () => Scalar, xs: NDArray[], returnValue: boolean): NDArray[] | { | ||
value: Scalar; | ||
gradients: NDArray[]; | ||
}; | ||
private gradientWrt(y, xs); | ||
scope<T extends ScopeResult>(name: string, scopeFn: (keep: <D1 extends keyof DataTypes, T1 extends NDArray<D1>>(ndarray: T1) => T1, track: <D2 extends keyof DataTypes, T2 extends NDArray<D2>>(ndarray: T2) => T2) => T, gradientsMode: boolean): T; | ||
startScope(gradientsMode: boolean): void; | ||
endScope(result: ScopeResultImmediate, gradientsMode: boolean): void; | ||
keep<T extends NDArray>(result: T): T; | ||
track<G extends keyof DataTypes, T extends NDArray<G>>(result: T): T; | ||
getBackend(): MathBackend; | ||
} |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var util = require("../../util"); | ||
var ndarray_1 = require("../ndarray"); | ||
var kernel_registry = require("./kernel_registry"); | ||
var tape_1 = require("./tape"); | ||
var tape_util = require("./tape_util"); | ||
var BackendEngine = (function () { | ||
function BackendEngine(backend) { | ||
function BackendEngine(backend, safeMode) { | ||
this.backend = backend; | ||
this.safeMode = safeMode; | ||
this.nextTapeNodeId = 0; | ||
this.gradientScopeCount = 0; | ||
this.debugMode = false; | ||
this.masterTape = new tape_1.Tape(backend); | ||
this.activeScope = { keep: [], track: [] }; | ||
this.scopeStack = [this.activeScope]; | ||
} | ||
@@ -28,3 +33,3 @@ BackendEngine.prototype.enableDebugMode = function () { | ||
var time = util.rightPad(performance.now() - start + "ms", 9); | ||
var paddedName = util.rightPad(name, 25); | ||
var paddedName = util.rightPad(kernelName, 25); | ||
var rank = result.rank; | ||
@@ -34,24 +39,140 @@ var size = result.size; | ||
console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange'); | ||
this.checkForNaN(vals, result.dtype, name); | ||
util.checkForNaN(vals, result.dtype, name); | ||
} | ||
var evaluatedNode = { | ||
name: "kernel: " + kernelName, | ||
kernel: kernelName, | ||
inputAndArgs: config, | ||
output: result, | ||
gradient: grad | ||
}; | ||
this.masterTape.addEvaluatedKernelNode(evaluatedNode); | ||
if (this.activeTape != null) { | ||
config = tape_util.stripUndefinedInputsFromInputConfig(config); | ||
var evaluatedNode = { | ||
id: this.nextTapeNodeId++, | ||
type: 'kernel', | ||
name: "kernel: " + kernelName, | ||
kernel: kernelName, | ||
inputAndArgs: config, | ||
output: result, | ||
gradient: grad | ||
}; | ||
this.activeTape.push(evaluatedNode); | ||
} | ||
return result; | ||
}; | ||
BackendEngine.prototype.gradients = function (f, xs, returnValue) { | ||
var _this = this; | ||
var gradientsMode = true; | ||
var result = this.scope('gradients', function () { | ||
var y = f(); | ||
if (y.rank !== 0) { | ||
throw new Error("Cannot compute gradient of non-scalar y output. " + | ||
("Got y with rank " + y.rank)); | ||
} | ||
var gradients = _this.gradientWrt(y, xs); | ||
if (returnValue) { | ||
return [y].concat(gradients); | ||
} | ||
else { | ||
return gradients; | ||
} | ||
}, gradientsMode); | ||
if (returnValue) { | ||
return { value: result[0], gradients: result.slice(1) }; | ||
} | ||
else { | ||
return result; | ||
} | ||
}; | ||
BackendEngine.prototype.gradientWrt = function (y, xs) { | ||
return this.masterTape.gradientWrt(y, xs); | ||
var filteredTape = tape_util.getFilteredNodesXToY(this.activeTape, xs, y); | ||
if (filteredTape.length === 0) { | ||
throw new Error("Cannot compute gradient: y is not a function of xs."); | ||
} | ||
var arrayAccumulatedGradientMap = {}; | ||
arrayAccumulatedGradientMap[y.id] = ndarray_1.Scalar.new(1); | ||
tape_util.backpropagateGradients(arrayAccumulatedGradientMap, filteredTape); | ||
var gradients = []; | ||
for (var i = 0; i < xs.length; i++) { | ||
gradients.push(arrayAccumulatedGradientMap[xs[i].id]); | ||
} | ||
return gradients; | ||
}; | ||
BackendEngine.prototype.checkForNaN = function (vals, dtype, name) { | ||
for (var i = 0; i < vals.length; i++) { | ||
if (util.isValNaN(vals[i], dtype)) { | ||
throw Error("The result of the last math." + name + " has NaNs."); | ||
BackendEngine.prototype.scope = function (name, scopeFn, gradientsMode) { | ||
var _this = this; | ||
this.startScope(gradientsMode); | ||
var keepFn = function (ndarray) { return _this.keep(ndarray); }; | ||
var trackFn = function (ndarray) { return ndarray; }; | ||
var result = scopeFn(keepFn, trackFn); | ||
if (result instanceof Promise) { | ||
result.then(function (r) { return _this.endScope(r, gradientsMode); }); | ||
return result; | ||
} | ||
else { | ||
this.endScope(result, gradientsMode); | ||
return result; | ||
} | ||
}; | ||
BackendEngine.prototype.startScope = function (gradientsMode) { | ||
if (gradientsMode && this.gradientScopeCount === 0) { | ||
this.activeTape = []; | ||
} | ||
if (gradientsMode) { | ||
this.gradientScopeCount++; | ||
} | ||
var newScopeArrays = { keep: [], track: [] }; | ||
this.scopeStack.push(newScopeArrays); | ||
this.activeScope = newScopeArrays; | ||
}; | ||
BackendEngine.prototype.endScope = function (result, gradientsMode) { | ||
var _this = this; | ||
var arraysToKeep = this.activeScope.keep; | ||
var arraysToTrackInParent = tape_util.extractNDArraysFromScopeResult(result); | ||
arraysToKeep = arraysToKeep.concat(arraysToTrackInParent); | ||
for (var i = 0; i < this.activeScope.track.length; i++) { | ||
var ndarray = this.activeScope.track[i]; | ||
if (util.isNDArrayInList(ndarray, arraysToKeep)) { | ||
continue; | ||
} | ||
if (this.activeTape != null) { | ||
arraysToTrackInParent.push(ndarray); | ||
} | ||
else { | ||
ndarray.dispose(); | ||
} | ||
} | ||
this.scopeStack.pop(); | ||
this.activeScope = this.scopeStack.length === 0 ? | ||
null : | ||
this.scopeStack[this.scopeStack.length - 1]; | ||
arraysToTrackInParent.forEach(function (ndarray) { | ||
if (!util.isNDArrayInList(ndarray, _this.activeScope.keep)) { | ||
_this.track(ndarray); | ||
} | ||
}); | ||
if (gradientsMode) { | ||
this.gradientScopeCount--; | ||
if (this.gradientScopeCount === 0) { | ||
this.activeTape = null; | ||
} | ||
} | ||
}; | ||
BackendEngine.prototype.keep = function (result) { | ||
if (this.scopeStack.length === 1) { | ||
if (this.safeMode) { | ||
throw new Error('You are using math in safe mode. Enclose all ' + | ||
'math.method() calls inside a scope: ' + | ||
'math.scope(() => {math.method();...}) to avoid memory ' + | ||
'leaks.'); | ||
} | ||
} | ||
this.activeScope.keep.push(result); | ||
return result; | ||
}; | ||
BackendEngine.prototype.track = function (result) { | ||
if (this.scopeStack.length === 1) { | ||
if (this.safeMode) { | ||
throw new Error('You are using math in safe mode. Enclose all ' + | ||
'math.method() calls inside a scope: ' + | ||
'math.scope(() => {math.method();...}) to avoid memory ' + | ||
'leaks.'); | ||
} | ||
} | ||
this.activeScope.track.push(result); | ||
return result; | ||
}; | ||
BackendEngine.prototype.getBackend = function () { | ||
@@ -58,0 +179,0 @@ return this.backend; |
@@ -11,5 +11,8 @@ import { Conv2DInfo } from '../conv_util'; | ||
export declare class MathBackendWebGL implements MathBackend { | ||
private gpgpu; | ||
private delayedStorage; | ||
private texData; | ||
writePixels(id: number, pixels: ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement, numChannels: number): void; | ||
write<T extends keyof DataTypes>(id: number, values: DataTypes[T], dtype: T, shape: number[]): void; | ||
private getOrMakeTexData(id, shape, dtype); | ||
readSync<T extends keyof DataTypes>(id: number): DataTypes[T]; | ||
@@ -21,7 +24,6 @@ read<T extends keyof DataTypes>(id: number): Promise<DataTypes[T]>; | ||
getTextureData(id: number): TextureData; | ||
private gpgpu; | ||
private textureManager; | ||
private binaryCache; | ||
private gpgpuCreatedLocally; | ||
constructor(gpgpu?: GPGPUContext); | ||
constructor(gpgpu?: GPGPUContext, delayedStorage?: boolean); | ||
getGPGPUContext(): GPGPUContext; | ||
@@ -40,3 +42,3 @@ clone<G extends keyof DataTypes, T extends NDArray<G>>(x: T): T; | ||
matMul(a: Array2D, b: Array2D, aOrientation: MatrixOrientation, bOrientation: MatrixOrientation): Array2D; | ||
multiply<T extends NDArray>(a: T, b: T): T; | ||
multiply<G extends keyof DataTypes>(a: NDArray<G>, b: NDArray<G>): NDArray<G>; | ||
batchNormalization2D(x: Array2D, mean: Array2D | Array1D, variance: Array2D | Array1D, varianceEpsilon: number, scale?: Array2D | Array1D, offset?: Array2D | Array1D): Array2D; | ||
@@ -57,4 +59,4 @@ batchNormalization3D(x: Array3D, mean: Array3D | Array1D, variance: Array3D | Array1D, varianceEpsilon: number, scale?: Array3D | Array1D, offset?: Array3D | Array1D): Array3D; | ||
divide(a: NDArray, b: NDArray): NDArray<'float32'>; | ||
add<T extends NDArray>(a: T, b: T): T; | ||
subtract<T extends NDArray>(a: T, b: T): T; | ||
add<G extends keyof DataTypes>(a: NDArray<G>, b: NDArray<G>): NDArray<G>; | ||
subtract<G extends keyof DataTypes>(a: NDArray<G>, b: NDArray<G>): NDArray<G>; | ||
pow<T extends NDArray>(a: T, b: NDArray<'int32'>): T; | ||
@@ -105,2 +107,4 @@ ceil<T extends NDArray>(x: T): T; | ||
private throwIfNoData(id); | ||
private uploadToGPU(id); | ||
private cacheOnCPU(id, float32Values?); | ||
} | ||
@@ -107,0 +111,0 @@ export declare class NDArrayMathGPU extends NDArrayMath { |
@@ -54,2 +54,3 @@ "use strict"; | ||
var reduce_util = require("../reduce_util"); | ||
var types = require("../types"); | ||
var types_1 = require("../types"); | ||
@@ -85,3 +86,6 @@ var argminmax_gpu_1 = require("./webgl/argminmax_gpu"); | ||
var MathBackendWebGL = (function () { | ||
function MathBackendWebGL(gpgpu) { | ||
function MathBackendWebGL(gpgpu, delayedStorage) { | ||
if (delayedStorage === void 0) { delayedStorage = true; } | ||
this.gpgpu = gpgpu; | ||
this.delayedStorage = delayedStorage; | ||
this.texData = {}; | ||
@@ -98,3 +102,2 @@ this.binaryCache = {}; | ||
else { | ||
this.gpgpu = gpgpu; | ||
this.gpgpuCreatedLocally = false; | ||
@@ -105,7 +108,8 @@ } | ||
MathBackendWebGL.prototype.writePixels = function (id, pixels, numChannels) { | ||
var shape = [pixels.height, pixels.width, numChannels]; | ||
var texShape = [shape[0], shape[1]]; | ||
var texture = this.textureManager.acquireTexture(texShape); | ||
this.gpgpu.uploadPixelDataToTexture(texture, pixels); | ||
var texShape = [pixels.height, pixels.width]; | ||
var texture = id in this.texData ? | ||
this.texData[id].texture : | ||
this.textureManager.acquireTexture(texShape); | ||
this.texData[id] = { | ||
values: null, | ||
texture: texture, | ||
@@ -117,27 +121,47 @@ textureType: tex_util_1.TextureType.RGBA_COLOR, | ||
}; | ||
this.gpgpu.uploadPixelDataToTexture(texture, pixels); | ||
}; | ||
MathBackendWebGL.prototype.write = function (id, values, dtype, shape) { | ||
var texShape = webgl_util.getTextureShapeFromLogicalShape(this.gpgpu.gl, shape); | ||
var texture = this.textureManager.acquireTexture(texShape); | ||
var textureType = tex_util_1.TextureType.DEFAULT; | ||
this.texData[id] = { texture: texture, textureType: textureType, texShape: texShape, dtype: dtype }; | ||
if (values != null) { | ||
this.gpgpu.uploadMatrixToTexture(texture, texShape[0], texShape[1], typedArrayToFloat32(values, dtype)); | ||
if (values == null) { | ||
throw new Error('MathBackendWebGL.write(): values can not be null'); | ||
} | ||
var _a = this.getOrMakeTexData(id, shape, dtype), texture = _a.texture, texShape = _a.texShape; | ||
if (texture != null) { | ||
this.textureManager.releaseTexture(texture, texShape); | ||
this.texData[id].texture = null; | ||
} | ||
this.texData[id].values = values; | ||
if (!this.delayedStorage) { | ||
this.uploadToGPU(id); | ||
} | ||
}; | ||
MathBackendWebGL.prototype.getOrMakeTexData = function (id, shape, dtype) { | ||
if (!(id in this.texData)) { | ||
var texShape = webgl_util.getTextureShapeFromLogicalShape(this.gpgpu.gl, shape); | ||
var textureType = tex_util_1.TextureType.DEFAULT; | ||
this.texData[id] = | ||
{ texture: null, values: null, textureType: textureType, texShape: texShape, dtype: dtype }; | ||
} | ||
return this.texData[id]; | ||
}; | ||
MathBackendWebGL.prototype.readSync = function (id) { | ||
this.throwIfNoData(id); | ||
var values; | ||
var _a = this.texData[id], texture = _a.texture, textureType = _a.textureType, texShape = _a.texShape, numChannels = _a.numChannels, dtype = _a.dtype; | ||
var _a = this.texData[id], texture = _a.texture, values = _a.values, textureType = _a.textureType, texShape = _a.texShape, numChannels = _a.numChannels; | ||
if (values != null) { | ||
this.cacheOnCPU(id); | ||
return values; | ||
} | ||
var float32Values; | ||
if (textureType === tex_util_1.TextureType.DEFAULT) { | ||
values = this.gpgpu.downloadMatrixFromTexture(texture, texShape[0], texShape[1]); | ||
float32Values = this.gpgpu.downloadMatrixFromTexture(texture, texShape[0], texShape[1]); | ||
} | ||
else { | ||
values = this.gpgpu.downloadMatrixFromRGBAColorTexture(texture, texShape[0], texShape[1], numChannels); | ||
float32Values = this.gpgpu.downloadMatrixFromRGBAColorTexture(texture, texShape[0], texShape[1], numChannels); | ||
} | ||
return float32ToTypedArray(values, dtype); | ||
this.cacheOnCPU(id, float32Values); | ||
return this.texData[id].values; | ||
}; | ||
MathBackendWebGL.prototype.read = function (id) { | ||
return __awaiter(this, void 0, void 0, function () { | ||
var _a, texture, textureType, texShape; | ||
var _a, texture, values, textureType, texShape, float32Values; | ||
return __generator(this, function (_b) { | ||
@@ -147,7 +171,15 @@ switch (_b.label) { | ||
this.throwIfNoData(id); | ||
_a = this.texData[id], texture = _a.texture, textureType = _a.textureType, texShape = _a.texShape; | ||
if (environment_1.ENV.get('WEBGL_GET_BUFFER_SUB_DATA_ASYNC_EXTENSION_ENABLED') && | ||
textureType === tex_util_1.TextureType.DEFAULT) { | ||
return [2, this.gpgpu.downloadMatrixFromTextureAsync(texture, texShape[0], texShape[1])]; | ||
_a = this.texData[id], texture = _a.texture, values = _a.values, textureType = _a.textureType, texShape = _a.texShape; | ||
if (values != null) { | ||
this.cacheOnCPU(id); | ||
return [2, values]; | ||
} | ||
if (!(environment_1.ENV.get('WEBGL_GET_BUFFER_SUB_DATA_ASYNC_EXTENSION_ENABLED') && | ||
textureType === tex_util_1.TextureType.DEFAULT)) return [3, 2]; | ||
return [4, this.gpgpu.downloadMatrixFromTextureAsync(texture, texShape[0], texShape[1])]; | ||
case 1: | ||
float32Values = _b.sent(); | ||
this.cacheOnCPU(id, float32Values); | ||
return [2, this.texData[id].values]; | ||
case 2: | ||
if (!environment_1.ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')) { | ||
@@ -157,3 +189,3 @@ return [2, this.readSync(id)]; | ||
return [4, this.gpgpu.runQuery(function () { })]; | ||
case 1: | ||
case 3: | ||
_b.sent(); | ||
@@ -186,3 +218,5 @@ return [2, this.readSync(id)]; | ||
var _a = this.texData[id], texture = _a.texture, texShape = _a.texShape; | ||
this.textureManager.releaseTexture(texture, texShape); | ||
if (texture != null) { | ||
this.textureManager.releaseTexture(texture, texShape); | ||
} | ||
delete this.texData[id]; | ||
@@ -192,7 +226,7 @@ } | ||
MathBackendWebGL.prototype.getTexture = function (id) { | ||
this.throwIfNoData(id); | ||
this.uploadToGPU(id); | ||
return this.texData[id].texture; | ||
}; | ||
MathBackendWebGL.prototype.getTextureData = function (id) { | ||
this.throwIfNoData(id); | ||
this.uploadToGPU(id); | ||
return this.texData[id]; | ||
@@ -262,3 +296,4 @@ }; | ||
var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.MUL, a.shape, b.shape); | ||
return this.compileAndRun(program, [a, b]); | ||
var output = this.makeOutputArray(program.outputShape, types.upcastType(a.dtype, b.dtype)); | ||
return this.compileAndRun(program, [a, b], output); | ||
}; | ||
@@ -394,7 +429,9 @@ MathBackendWebGL.prototype.batchNormalization2D = function (x, mean, variance, varianceEpsilon, scale, offset) { | ||
var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.ADD, a.shape, b.shape); | ||
return this.compileAndRun(program, [a, b]); | ||
var output = this.makeOutputArray(program.outputShape, types.upcastType(a.dtype, b.dtype)); | ||
return this.compileAndRun(program, [a, b], output); | ||
}; | ||
MathBackendWebGL.prototype.subtract = function (a, b) { | ||
var program = new binaryop_gpu_1.BinaryOpProgram(binaryop_gpu.SUB, a.shape, b.shape); | ||
return this.compileAndRun(program, [a, b]); | ||
var output = this.makeOutputArray(program.outputShape, types.upcastType(a.dtype, b.dtype)); | ||
return this.compileAndRun(program, [a, b], output); | ||
}; | ||
@@ -576,6 +613,7 @@ MathBackendWebGL.prototype.pow = function (a, b) { | ||
var inputsData = inputs.map(function (input) { | ||
_this.throwIfNoData(input.id); | ||
_this.uploadToGPU(input.id); | ||
return { array: input, texData: _this.texData[input.id] }; | ||
}); | ||
this.throwIfNoData(output.id); | ||
this.getOrMakeTexData(output.id, output.shape, output.dtype); | ||
this.uploadToGPU(output.id); | ||
var outputData = { array: output, texData: this.texData[output.id] }; | ||
@@ -615,2 +653,25 @@ var key = gpgpu_math.makeShaderKey(program, inputsData, outputData); | ||
}; | ||
MathBackendWebGL.prototype.uploadToGPU = function (id) { | ||
this.throwIfNoData(id); | ||
var _a = this.texData[id], texShape = _a.texShape, values = _a.values, texture = _a.texture, dtype = _a.dtype; | ||
if (texture != null) { | ||
return; | ||
} | ||
var newTexture = this.textureManager.acquireTexture(texShape); | ||
this.texData[id].texture = newTexture; | ||
if (values != null) { | ||
this.gpgpu.uploadMatrixToTexture(newTexture, texShape[0], texShape[1], typedArrayToFloat32(values, dtype)); | ||
} | ||
}; | ||
MathBackendWebGL.prototype.cacheOnCPU = function (id, float32Values) { | ||
var dontKeepCopyOnGPU = this.delayedStorage; | ||
var _a = this.texData[id], texture = _a.texture, texShape = _a.texShape, dtype = _a.dtype; | ||
if (dontKeepCopyOnGPU && texture != null) { | ||
this.textureManager.releaseTexture(texture, texShape); | ||
this.texData[id].texture = null; | ||
} | ||
if (float32Values != null) { | ||
this.texData[id].values = float32ToTypedArray(float32Values, dtype); | ||
} | ||
}; | ||
return MathBackendWebGL; | ||
@@ -617,0 +678,0 @@ }()); |
@@ -0,24 +1,29 @@ | ||
import { NamedArrayMap } from '../../util'; | ||
import { NDArray } from '../ndarray'; | ||
import { KernelConfigRegistry } from './kernel_registry'; | ||
export interface TapeNode { | ||
export declare type Tape = Array<TapeNode<TapeNodeOutput>>; | ||
export declare type TapeNodeOutput = NDArray | NamedArrayMap; | ||
export declare type TapeNodeType = 'kernel' | 'subtape'; | ||
export interface TapeNode<T extends TapeNodeOutput> { | ||
id: number; | ||
type: TapeNodeType; | ||
name: string; | ||
inputAndArgs: TapeNodeInputConfig; | ||
output: NDArray; | ||
gradient: (dy: NDArray, y: NDArray) => TapeNodeInputGradientArrays; | ||
output: T; | ||
gradient: (dy: T, y: T) => TapeNodeInputGradientArrays; | ||
subtape?: Tape; | ||
} | ||
export interface TapeNodeInputConfig { | ||
inputs: TapeNodeInputArrays; | ||
inputs: NamedArrayMap; | ||
} | ||
export declare type TapeNodeInputArrays = { | ||
[inputName: string]: NDArray; | ||
}; | ||
export declare type TapeNodeInputGradientArrays = { | ||
[inputName: string]: () => NDArray; | ||
}; | ||
export interface KernelNode extends TapeNode { | ||
export interface KernelNode extends TapeNode<NDArray> { | ||
kernel: keyof KernelConfigRegistry; | ||
inputAndArgs: KernelInputConfig; | ||
output: NDArray; | ||
} | ||
export interface KernelInputConfig extends TapeNodeInputConfig { | ||
inputs: TapeNodeInputArrays; | ||
inputs: NamedArrayMap; | ||
args?: { | ||
@@ -25,0 +30,0 @@ [argName: string]: any; |
import { NDArray } from '../ndarray'; | ||
import { MathBackend } from './backend'; | ||
import { TapeNode } from './tape_types'; | ||
export declare function getFilteredNodesXToY(tapeNodes: TapeNode[], xs: NDArray[], y: NDArray): TapeNode[]; | ||
export declare function backpropagateGradients(backend: MathBackend, arrayAccumulatedGradientMap: { | ||
import { Tape, TapeNodeInputConfig } from './tape_types'; | ||
export declare function getFilteredNodesXToY(tape: Tape, xs: NDArray[], y: NDArray): Tape; | ||
export declare function backpropagateGradients(arrayAccumulatedGradientMap: { | ||
[ndarrayId: number]: NDArray; | ||
}, filteredNodes: TapeNode[]): void; | ||
}, filteredTape: Tape): void; | ||
export declare function computeInputs(tape: Tape): { | ||
[idx: string]: NDArray; | ||
}; | ||
export declare type ScopeResultImmediate = void | NDArray | NDArray[] | { | ||
[key: string]: NDArray; | ||
}; | ||
export declare type ScopeResult = ScopeResultImmediate | Promise<ScopeResultImmediate>; | ||
export declare function extractNDArraysFromScopeResult(result: ScopeResultImmediate): NDArray[]; | ||
export declare function stripUndefinedInputsFromInputConfig(config: TapeNodeInputConfig): TapeNodeInputConfig; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
function getFilteredNodesXToY(tapeNodes, xs, y) { | ||
var environment_1 = require("../../environment"); | ||
var ndarray_1 = require("../ndarray"); | ||
function getFilteredNodesXToY(tape, xs, y) { | ||
var arraysFromX = {}; | ||
var nodesFromX = {}; | ||
for (var i = 0; i < xs.length; i++) { | ||
arraysFromX[xs[i].id] = true; | ||
} | ||
for (var i = 0; i < tapeNodes.length; i++) { | ||
var node = tapeNodes[i]; | ||
for (var i = 0; i < tape.length; i++) { | ||
var node = tape[i]; | ||
var nodeInputs = node.inputAndArgs.inputs; | ||
for (var inputName in nodeInputs) { | ||
var input = nodeInputs[inputName]; | ||
var anyInputFromX = false; | ||
for (var j = 0; j < xs.length; j++) { | ||
if (arraysFromX[input.id]) { | ||
arraysFromX[node.output.id] = true; | ||
if (node.output instanceof ndarray_1.NDArray) { | ||
arraysFromX[node.output.id] = true; | ||
} | ||
else { | ||
var keys = Object.keys(node.output); | ||
for (var _i = 0, keys_1 = keys; _i < keys_1.length; _i++) { | ||
var key = keys_1[_i]; | ||
arraysFromX[node.output[key].id] = true; | ||
} | ||
} | ||
anyInputFromX = true; | ||
nodesFromX[node.id] = true; | ||
break; | ||
} | ||
} | ||
if (arraysFromX[node.output.id]) { | ||
if (anyInputFromX) { | ||
break; | ||
@@ -26,15 +41,31 @@ } | ||
arraysLeadToY[y.id] = true; | ||
for (var i = tapeNodes.length - 1; i >= 0; i--) { | ||
var node = tapeNodes[i]; | ||
var nodesToY = {}; | ||
for (var i = tape.length - 1; i >= 0; i--) { | ||
var node = tape[i]; | ||
var nodeInputs = node.inputAndArgs.inputs; | ||
if (arraysLeadToY[node.output.id]) { | ||
for (var inputName in nodeInputs) { | ||
arraysLeadToY[nodeInputs[inputName].id] = true; | ||
var outputs = []; | ||
if (node.output instanceof ndarray_1.NDArray) { | ||
outputs.push(node.output); | ||
} | ||
else { | ||
var keys = Object.keys(node.output); | ||
for (var _a = 0, keys_2 = keys; _a < keys_2.length; _a++) { | ||
var key = keys_2[_a]; | ||
outputs.push(node.output[key]); | ||
} | ||
} | ||
for (var j = 0; j < outputs.length; j++) { | ||
if (arraysLeadToY[outputs[j].id]) { | ||
for (var inputName in nodeInputs) { | ||
arraysLeadToY[nodeInputs[inputName].id] = true; | ||
nodesToY[node.id] = true; | ||
} | ||
break; | ||
} | ||
} | ||
} | ||
var filteredTapeNodes = []; | ||
for (var i = 0; i < tapeNodes.length; i++) { | ||
var node = tapeNodes[i]; | ||
if (arraysFromX[node.output.id] && arraysLeadToY[node.output.id]) { | ||
var filteredTape = []; | ||
for (var i = 0; i < tape.length; i++) { | ||
var node = tape[i]; | ||
if (nodesFromX[node.id] && nodesToY[node.id]) { | ||
var prunedInputs = {}; | ||
@@ -47,14 +78,39 @@ for (var inputName in node.inputAndArgs.inputs) { | ||
} | ||
var prunedOutputs = void 0; | ||
if (node.output instanceof ndarray_1.NDArray) { | ||
prunedOutputs = node.output; | ||
} | ||
else { | ||
prunedOutputs = {}; | ||
for (var outputName in node.output) { | ||
var output = node.output[outputName]; | ||
if (arraysLeadToY[output.id]) { | ||
prunedOutputs[outputName] = node.output[outputName]; | ||
} | ||
} | ||
} | ||
var prunedNode = Object.assign({}, node); | ||
prunedNode.inputAndArgs = { inputs: prunedInputs }; | ||
filteredTapeNodes.push(prunedNode); | ||
prunedNode.output = prunedOutputs; | ||
filteredTape.push(prunedNode); | ||
} | ||
} | ||
return filteredTapeNodes; | ||
return filteredTape; | ||
} | ||
exports.getFilteredNodesXToY = getFilteredNodesXToY; | ||
function backpropagateGradients(backend, arrayAccumulatedGradientMap, filteredNodes) { | ||
for (var i = filteredNodes.length - 1; i >= 0; i--) { | ||
var node = filteredNodes[i]; | ||
var dy = arrayAccumulatedGradientMap[node.output.id]; | ||
function backpropagateGradients(arrayAccumulatedGradientMap, filteredTape) { | ||
for (var i = filteredTape.length - 1; i >= 0; i--) { | ||
var node = filteredTape[i]; | ||
var dy = void 0; | ||
if (node.output instanceof ndarray_1.NDArray) { | ||
dy = arrayAccumulatedGradientMap[node.output.id]; | ||
} | ||
else { | ||
dy = {}; | ||
var keys = Object.keys(node.output); | ||
for (var _i = 0, keys_3 = keys; _i < keys_3.length; _i++) { | ||
var key = keys_3[_i]; | ||
dy[key] = arrayAccumulatedGradientMap[node.output[key].id]; | ||
} | ||
} | ||
if (node.gradient == null) { | ||
@@ -78,3 +134,3 @@ throw new Error("Cannot compute gradient: gradient function not found for\n " + node.name + "."); | ||
arrayAccumulatedGradientMap[activation.id] = | ||
backend.add(curGradient, grad); | ||
environment_1.ENV.math.add(curGradient, grad); | ||
curGradient.dispose(); | ||
@@ -86,1 +142,62 @@ } | ||
exports.backpropagateGradients = backpropagateGradients; | ||
function computeInputs(tape) { | ||
var outputArrays = {}; | ||
for (var i = 0; i < tape.length; i++) { | ||
var node = tape[i]; | ||
if (node.output instanceof ndarray_1.NDArray) { | ||
outputArrays[node.output.id] = true; | ||
} | ||
else { | ||
var keys = Object.keys(node.output); | ||
for (var _i = 0, keys_4 = keys; _i < keys_4.length; _i++) { | ||
var key = keys_4[_i]; | ||
outputArrays[node.output[key].id] = true; | ||
} | ||
} | ||
} | ||
var inputArrays = {}; | ||
var inputArraysSeen = {}; | ||
var idx = 0; | ||
for (var i = 0; i < tape.length; i++) { | ||
var node = tape[i]; | ||
var inputs = node.inputAndArgs.inputs; | ||
var keys = Object.keys(inputs); | ||
for (var _a = 0, keys_5 = keys; _a < keys_5.length; _a++) { | ||
var key = keys_5[_a]; | ||
if (!outputArrays[inputs[key].id] && !inputArraysSeen[inputs[key].id]) { | ||
inputArrays[(idx++).toString()] = inputs[key]; | ||
inputArraysSeen[inputs[key].id] = true; | ||
} | ||
} | ||
} | ||
return inputArrays; | ||
} | ||
exports.computeInputs = computeInputs; | ||
function extractNDArraysFromScopeResult(result) { | ||
if (result == null) { | ||
return []; | ||
} | ||
if (result instanceof ndarray_1.NDArray) { | ||
return [result]; | ||
} | ||
var list = []; | ||
var resultObj = result; | ||
for (var k in resultObj) { | ||
var val = resultObj[k]; | ||
if (val instanceof ndarray_1.NDArray) { | ||
list.push(val); | ||
} | ||
} | ||
return list; | ||
} | ||
exports.extractNDArraysFromScopeResult = extractNDArraysFromScopeResult; | ||
function stripUndefinedInputsFromInputConfig(config) { | ||
var keys = Object.keys(config.inputs); | ||
keys.forEach(function (key) { | ||
if (config.inputs[key] == null) { | ||
delete config.inputs[key]; | ||
} | ||
}); | ||
return config; | ||
} | ||
exports.stripUndefinedInputsFromInputConfig = stripUndefinedInputsFromInputConfig; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface ArgMaxNode extends KernelNode { | ||
@@ -14,3 +15,3 @@ inputAndArgs: ArgMaxInputConfig; | ||
} | ||
export interface ArgMaxInputArrays extends TapeNodeInputArrays { | ||
export interface ArgMaxInputArrays extends NamedArrayMap { | ||
x: NDArray; | ||
@@ -32,3 +33,3 @@ } | ||
} | ||
export interface ArgMinInputArrays extends TapeNodeInputArrays { | ||
export interface ArgMinInputArrays extends NamedArrayMap { | ||
x: NDArray; | ||
@@ -35,0 +36,0 @@ } |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array1D, Array2D, Array3D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface BatchNorm3DNode extends KernelNode { | ||
@@ -14,3 +15,3 @@ inputAndArgs: BatchNorm3DInputConfig; | ||
} | ||
export interface BatchNorm3DInputArrays extends TapeNodeInputArrays { | ||
export interface BatchNorm3DInputArrays extends NamedArrayMap { | ||
x: Array3D; | ||
@@ -40,3 +41,3 @@ mean: Array3D | Array1D; | ||
} | ||
export interface BatchNorm2DInputArrays extends TapeNodeInputArrays { | ||
export interface BatchNorm2DInputArrays extends NamedArrayMap { | ||
x: Array2D; | ||
@@ -43,0 +44,0 @@ mean: Array2D | Array1D; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface BinaryNode extends KernelNode { | ||
@@ -11,3 +12,3 @@ inputAndArgs: BinaryInputConfig; | ||
} | ||
export interface BinaryInputArrays extends TapeNodeInputArrays { | ||
export interface BinaryInputArrays extends NamedArrayMap { | ||
a: NDArray; | ||
@@ -14,0 +15,0 @@ b: NDArray; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array1D, Array2D, Array3D, Array4D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface Concat1DNode extends KernelNode { | ||
@@ -11,3 +12,3 @@ inputAndArgs: Concat1DInputConfig; | ||
} | ||
export interface Concat1DInputArrays extends TapeNodeInputArrays { | ||
export interface Concat1DInputArrays extends NamedArrayMap { | ||
a: Array1D; | ||
@@ -31,3 +32,3 @@ b: Array1D; | ||
} | ||
export interface Concat2DInputArrays extends TapeNodeInputArrays { | ||
export interface Concat2DInputArrays extends NamedArrayMap { | ||
a: Array2D; | ||
@@ -51,3 +52,3 @@ b: Array2D; | ||
} | ||
export interface Concat3DInputArrays extends TapeNodeInputArrays { | ||
export interface Concat3DInputArrays extends NamedArrayMap { | ||
a: Array3D; | ||
@@ -71,3 +72,3 @@ b: Array3D; | ||
} | ||
export interface Concat4DInputArrays extends TapeNodeInputArrays { | ||
export interface Concat4DInputArrays extends NamedArrayMap { | ||
a: Array4D; | ||
@@ -74,0 +75,0 @@ b: Array4D; |
@@ -0,4 +1,5 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Conv2DInfo } from '../../conv_util'; | ||
import { Array1D, Array4D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface Conv2DNode extends KernelNode { | ||
@@ -15,3 +16,3 @@ inputAndArgs: Conv2DInputConfig; | ||
} | ||
export interface Conv2DInputArrays extends TapeNodeInputArrays { | ||
export interface Conv2DInputArrays extends NamedArrayMap { | ||
x: Array4D; | ||
@@ -37,3 +38,3 @@ filter: Array4D; | ||
} | ||
export interface Conv2DDerInputInputArrays extends TapeNodeInputArrays { | ||
export interface Conv2DDerInputInputArrays extends NamedArrayMap { | ||
dy: Array4D; | ||
@@ -57,3 +58,3 @@ filter: Array4D; | ||
} | ||
export interface Conv2DDerFilterInputArrays extends TapeNodeInputArrays { | ||
export interface Conv2DDerFilterInputArrays extends NamedArrayMap { | ||
x: Array4D; | ||
@@ -74,3 +75,3 @@ dy: Array4D; | ||
} | ||
export interface Conv2DDerBiasInputArrays extends TapeNodeInputArrays { | ||
export interface Conv2DDerBiasInputArrays extends NamedArrayMap { | ||
dy: Array4D; | ||
@@ -92,3 +93,3 @@ } | ||
} | ||
export interface DepthwiseConv2DInputArrays extends TapeNodeInputArrays { | ||
export interface DepthwiseConv2DInputArrays extends NamedArrayMap { | ||
x: Array4D; | ||
@@ -95,0 +96,0 @@ filter: Array4D; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface EqualNode extends KernelNode { | ||
@@ -11,3 +12,3 @@ inputAndArgs: EqualInputConfig; | ||
} | ||
export interface EqualInputArrays extends TapeNodeInputArrays { | ||
export interface EqualInputArrays extends NamedArrayMap { | ||
a: NDArray; | ||
@@ -14,0 +15,0 @@ b: NDArray; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array2D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface MatMulNode extends KernelNode { | ||
@@ -15,3 +16,3 @@ inputAndArgs: MatMulInputConfig; | ||
} | ||
export interface MatMulInputArrays extends TapeNodeInputArrays { | ||
export interface MatMulInputArrays extends NamedArrayMap { | ||
a: Array2D; | ||
@@ -18,0 +19,0 @@ b: Array2D; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { DataTypes, NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface MinNode<G extends keyof DataTypes> extends KernelNode { | ||
@@ -11,3 +12,3 @@ inputAndArgs: MinInputConfig<G>; | ||
} | ||
export interface MinInputArrays<G extends keyof DataTypes> extends TapeNodeInputArrays { | ||
export interface MinInputArrays<G extends keyof DataTypes> extends NamedArrayMap { | ||
x: NDArray<G>; | ||
@@ -26,3 +27,3 @@ } | ||
} | ||
export interface MaxInputArrays<G extends keyof DataTypes> extends TapeNodeInputArrays { | ||
export interface MaxInputArrays<G extends keyof DataTypes> extends NamedArrayMap { | ||
x: NDArray<G>; | ||
@@ -29,0 +30,0 @@ } |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array2D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface MultinomialNode extends KernelNode { | ||
@@ -15,3 +16,3 @@ inputAndArgs: MultinomialInputConfig; | ||
} | ||
export interface MultinomialInputArrays extends TapeNodeInputArrays { | ||
export interface MultinomialInputArrays extends NamedArrayMap { | ||
probs: Array2D; | ||
@@ -18,0 +19,0 @@ } |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array1D, Array2D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface OneHotNode extends KernelNode { | ||
@@ -16,3 +17,3 @@ inputAndArgs: OneHotInputConfig; | ||
} | ||
export interface OneHotInputArrays extends TapeNodeInputArrays { | ||
export interface OneHotInputArrays extends NamedArrayMap { | ||
indices: Array1D; | ||
@@ -19,0 +20,0 @@ } |
@@ -0,4 +1,5 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Conv2DInfo } from '../../conv_util'; | ||
import { Array4D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface PoolNode extends KernelNode { | ||
@@ -15,3 +16,3 @@ inputAndArgs: PoolInputConfig; | ||
} | ||
export interface PoolInputArrays extends TapeNodeInputArrays { | ||
export interface PoolInputArrays extends NamedArrayMap { | ||
x: Array4D; | ||
@@ -33,3 +34,3 @@ } | ||
} | ||
export interface PoolBackpropInputArrays extends TapeNodeInputArrays { | ||
export interface PoolBackpropInputArrays extends NamedArrayMap { | ||
dy: Array4D; | ||
@@ -36,0 +37,0 @@ x: Array4D; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface PowNode<T extends NDArray> extends KernelNode { | ||
@@ -11,3 +12,3 @@ inputAndArgs: PowInputConfig<T>; | ||
} | ||
export interface PowInputArrays<T extends NDArray> extends TapeNodeInputArrays { | ||
export interface PowInputArrays<T extends NDArray> extends NamedArrayMap { | ||
a: T; | ||
@@ -14,0 +15,0 @@ b: NDArray<'int32'>; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface PReLUNode<T extends NDArray> extends KernelNode { | ||
@@ -11,3 +12,3 @@ inputAndArgs: PReLUInputConfig<T>; | ||
} | ||
export interface PReLUInputArrays<T extends NDArray> extends TapeNodeInputArrays { | ||
export interface PReLUInputArrays<T extends NDArray> extends NamedArrayMap { | ||
x: T; | ||
@@ -14,0 +15,0 @@ alpha: T; |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array3D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface ResizeBilinear3DNode extends KernelNode { | ||
@@ -15,3 +16,3 @@ inputAndArgs: ResizeBilinear3DInputConfig; | ||
} | ||
export interface ResizeBilinear3DInputArrays extends TapeNodeInputArrays { | ||
export interface ResizeBilinear3DInputArrays extends NamedArrayMap { | ||
x: Array3D; | ||
@@ -18,0 +19,0 @@ } |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array1D, Array2D, Array3D, Array4D } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface Slice1DNode extends KernelNode { | ||
@@ -15,3 +16,3 @@ inputAndArgs: Slice1DInputConfig; | ||
} | ||
export interface Slice1DInputArrays extends TapeNodeInputArrays { | ||
export interface Slice1DInputArrays extends NamedArrayMap { | ||
x: Array1D; | ||
@@ -34,3 +35,3 @@ } | ||
} | ||
export interface Slice2DInputArrays extends TapeNodeInputArrays { | ||
export interface Slice2DInputArrays extends NamedArrayMap { | ||
x: Array2D; | ||
@@ -53,3 +54,3 @@ } | ||
} | ||
export interface Slice3DInputArrays extends TapeNodeInputArrays { | ||
export interface Slice3DInputArrays extends NamedArrayMap { | ||
x: Array3D; | ||
@@ -72,3 +73,3 @@ } | ||
} | ||
export interface Slice4DInputArrays extends TapeNodeInputArrays { | ||
export interface Slice4DInputArrays extends NamedArrayMap { | ||
x: Array4D; | ||
@@ -75,0 +76,0 @@ } |
@@ -0,4 +1,5 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { DataTypes, NDArray } from '../../ndarray'; | ||
import { SumTypes } from '../../types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface SumNode<T extends keyof DataTypes> extends KernelNode { | ||
@@ -15,3 +16,3 @@ inputAndArgs: SumInputConfig<T>; | ||
} | ||
export interface SumInputArrays<T extends keyof DataTypes> extends TapeNodeInputArrays { | ||
export interface SumInputArrays<T extends keyof DataTypes> extends NamedArrayMap { | ||
x: NDArray<T>; | ||
@@ -18,0 +19,0 @@ } |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { Array1D, DataTypes, NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface TopKValuesNode<D extends keyof DataTypes, T extends NDArray<D>> extends KernelNode { | ||
@@ -14,3 +15,3 @@ inputAndArgs: TopKValuesInputConfig<T>; | ||
} | ||
export interface TopKValuesInputArrays<T extends NDArray> extends TapeNodeInputArrays { | ||
export interface TopKValuesInputArrays<T extends NDArray> extends NamedArrayMap { | ||
x: T; | ||
@@ -32,3 +33,3 @@ } | ||
} | ||
export interface TopKIndicesInputArrays extends TapeNodeInputArrays { | ||
export interface TopKIndicesInputArrays extends NamedArrayMap { | ||
x: NDArray; | ||
@@ -35,0 +36,0 @@ } |
@@ -0,3 +1,4 @@ | ||
import { NamedArrayMap } from '../../../util'; | ||
import { NDArray } from '../../ndarray'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays } from '../tape_types'; | ||
import { KernelInputConfig, KernelNode, TapeNodeInputGradientArrays } from '../tape_types'; | ||
export interface UnaryNode<T extends NDArray> extends KernelNode { | ||
@@ -11,3 +12,3 @@ inputAndArgs: UnaryInputConfig<T>; | ||
} | ||
export interface UnaryInputArrays<T extends NDArray> extends TapeNodeInputArrays { | ||
export interface UnaryInputArrays<T extends NDArray> extends NamedArrayMap { | ||
x: T; | ||
@@ -14,0 +15,0 @@ } |
@@ -12,2 +12,3 @@ import { DataTypes } from '../../ndarray'; | ||
numChannels?: number; | ||
values: DataTypes[keyof DataTypes]; | ||
} | ||
@@ -14,0 +15,0 @@ export declare function getUnpackedMatrixTextureShapeWidthHeight(rows: number, columns: number): [number, number]; |
@@ -8,2 +8,3 @@ import { GPGPUContext } from './gpgpu_context'; | ||
private logEnabled; | ||
private allocatedTextures; | ||
private usedTextureCount; | ||
@@ -10,0 +11,0 @@ constructor(gpgpu: GPGPUContext); |
@@ -10,2 +10,3 @@ "use strict"; | ||
this.logEnabled = false; | ||
this.allocatedTextures = []; | ||
this.usedTextureCount = {}; | ||
@@ -30,3 +31,5 @@ } | ||
this.log(); | ||
return this.gpgpu.createMatrixTexture(shapeRC[0], shapeRC[1]); | ||
var newTexture = this.gpgpu.createMatrixTexture(shapeRC[0], shapeRC[1]); | ||
this.allocatedTextures.push(newTexture); | ||
return newTexture; | ||
}; | ||
@@ -58,9 +61,11 @@ TextureManager.prototype.releaseTexture = function (texture, shape) { | ||
TextureManager.prototype.dispose = function () { | ||
for (var shape in this.freeTextures) { | ||
if (this.freeTextures.hasOwnProperty(shape)) { | ||
for (var i = 0; i < this.freeTextures[shape].length; i++) { | ||
this.gpgpu.deleteMatrixTexture(this.freeTextures[shape][i]); | ||
} | ||
} | ||
} | ||
var _this = this; | ||
this.allocatedTextures.forEach(function (texture) { | ||
_this.gpgpu.deleteMatrixTexture(texture); | ||
}); | ||
this.freeTextures = null; | ||
this.allocatedTextures = null; | ||
this.usedTextureCount = null; | ||
this.numUsedTextures = 0; | ||
this.numFreeTextures = 0; | ||
}; | ||
@@ -67,0 +72,0 @@ return TextureManager; |
import { BackendType } from '../environment'; | ||
import { NamedArrayMap } from '../util'; | ||
import { NDArrayStorage } from './backends/backend'; | ||
import { MathBackend } from './backends/backend'; | ||
import { BackendEngine } from './backends/backend_engine'; | ||
import { ScopeResult, ScopeResultImmediate } from './backends/tape_util'; | ||
import { MatrixOrientation } from './backends/types/matmul'; | ||
import { Array1D, Array2D, Array3D, Array4D, DataTypes, NDArray, Scalar } from './ndarray'; | ||
import { SumTypes } from './types'; | ||
export declare type ScopeResultImmediate = void | NDArray | NDArray[] | { | ||
[key: string]: NDArray; | ||
}; | ||
export declare type ScopeResult = ScopeResultImmediate | Promise<ScopeResultImmediate>; | ||
export interface LSTMCell { | ||
@@ -20,5 +18,4 @@ (data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D]; | ||
export declare class NDArrayMath implements NDArrayStorage, NDArrayManager { | ||
private safeMode; | ||
protected backendEngine: BackendEngine; | ||
private numArrays; | ||
private registeredArrays; | ||
private backend; | ||
@@ -33,13 +30,8 @@ private customBackend; | ||
read<T extends keyof DataTypes>(id: number): Promise<DataTypes[T]>; | ||
private ndarrayScopes; | ||
private activeScope; | ||
private ndarraysToKeep; | ||
private activeScopeNDArraysToKeep; | ||
constructor(backend: BackendType | MathBackend, safeMode: boolean); | ||
enableDebugMode(): void; | ||
scope<T extends ScopeResult>(scopeFn: (keep: <D1 extends keyof DataTypes, T1 extends NDArray<D1>>(ndarray: T1) => T1, track: <D2 extends keyof DataTypes, T2 extends NDArray<D2>>(ndarray: T2) => T2) => T): T; | ||
enableDebugMode(): void; | ||
gradientsScope<T extends ScopeResult>(scopeFn: (keep: <D1 extends keyof DataTypes, T1 extends NDArray<D1>>(ndarray: T1) => T1, track: <D2 extends keyof DataTypes, T2 extends NDArray<D2>>(ndarray: T2) => T2) => T): T; | ||
startScope(): void; | ||
private extractNDArraysFromScopeResult(result); | ||
endScope(result: ScopeResultImmediate): void; | ||
private isNDArrayDataInList(ndarray, ndarrayList); | ||
keep<T extends NDArray>(result: T): T; | ||
@@ -153,6 +145,8 @@ track<G extends keyof DataTypes, T extends NDArray<G>>(result: T): T; | ||
private normInternal<G>(x, p, axis?); | ||
gradientWrt<T extends NDArray | { | ||
[xName: string]: NDArray; | ||
}>(y: Scalar, x: T): T; | ||
gradients<T extends NDArray | NamedArrayMap>(f: () => Scalar, x: T): T; | ||
valueAndGradients<T extends NDArray | NamedArrayMap>(f: () => Scalar, x: T): { | ||
value: Scalar; | ||
gradients: T; | ||
}; | ||
disposeData(id: number): void; | ||
} |
@@ -15,8 +15,4 @@ "use strict"; | ||
function NDArrayMath(backend, safeMode) { | ||
this.safeMode = safeMode; | ||
this.numArrays = 0; | ||
this.registeredArrays = new Set(); | ||
this.customBackend = false; | ||
this.ndarrayScopes = []; | ||
this.ndarraysToKeep = []; | ||
this.activeScopeNDArraysToKeep = []; | ||
if (typeof backend === 'string') { | ||
@@ -29,3 +25,3 @@ this.backend = environment_1.ENV.getBackend(backend); | ||
} | ||
this.backendEngine = new backend_engine_1.BackendEngine(this.backend); | ||
this.backendEngine = new backend_engine_1.BackendEngine(this.backend, safeMode); | ||
} | ||
@@ -36,7 +32,10 @@ NDArrayMath.prototype.time = function (query) { | ||
NDArrayMath.prototype.getNumArrays = function () { | ||
return this.numArrays; | ||
return this.registeredArrays.size; | ||
}; | ||
NDArrayMath.prototype.register = function (a) { | ||
this.track(a); | ||
this.numArrays++; | ||
if (this.registeredArrays.has(a.id)) { | ||
throw new Error("NDArray with id " + a.id + " was already registered"); | ||
} | ||
this.registeredArrays.add(a.id); | ||
this.backendEngine.track(a); | ||
}; | ||
@@ -55,17 +54,2 @@ NDArrayMath.prototype.writePixels = function (id, pixels, numChannels) { | ||
}; | ||
NDArrayMath.prototype.scope = function (scopeFn) { | ||
var _this = this; | ||
this.startScope(); | ||
var keepFn = function (ndarray) { return _this.keep(ndarray); }; | ||
var trackFn = function (ndarray) { return ndarray; }; | ||
var result = scopeFn(keepFn, trackFn); | ||
if (result instanceof Promise) { | ||
result.then(function (r) { return _this.endScope(r); }); | ||
return result; | ||
} | ||
else { | ||
this.endScope(result); | ||
return result; | ||
} | ||
}; | ||
NDArrayMath.prototype.enableDebugMode = function () { | ||
@@ -77,86 +61,23 @@ this.backendEngine.enableDebugMode(); | ||
}; | ||
NDArrayMath.prototype.scope = function (scopeFn) { | ||
var gradientsMode = false; | ||
return this.backendEngine.scope('scope', scopeFn, gradientsMode); | ||
}; | ||
NDArrayMath.prototype.gradientsScope = function (scopeFn) { | ||
var gradientsMode = true; | ||
return this.backendEngine.scope('gradientsScope', scopeFn, gradientsMode); | ||
}; | ||
NDArrayMath.prototype.startScope = function () { | ||
var newScope = []; | ||
this.ndarrayScopes.push(newScope); | ||
this.activeScope = newScope; | ||
var newNDArraysToKeep = []; | ||
this.ndarraysToKeep.push(newNDArraysToKeep); | ||
this.activeScopeNDArraysToKeep = newNDArraysToKeep; | ||
var gradientsMode = false; | ||
this.backendEngine.startScope(gradientsMode); | ||
}; | ||
NDArrayMath.prototype.extractNDArraysFromScopeResult = function (result) { | ||
if (result == null) { | ||
return []; | ||
} | ||
if (result instanceof ndarray_1.NDArray) { | ||
return [result]; | ||
} | ||
var list = []; | ||
var resultObj = result; | ||
for (var k in resultObj) { | ||
var val = resultObj[k]; | ||
if (val instanceof ndarray_1.NDArray) { | ||
list.push(val); | ||
} | ||
} | ||
return list; | ||
}; | ||
NDArrayMath.prototype.endScope = function (result) { | ||
var _this = this; | ||
var arraysToKeep = this.activeScopeNDArraysToKeep; | ||
var resultArrays = this.extractNDArraysFromScopeResult(result); | ||
arraysToKeep = arraysToKeep.concat(resultArrays); | ||
for (var i = 0; i < this.activeScope.length; i++) { | ||
var ndarray = this.activeScope[i]; | ||
if (this.isNDArrayDataInList(ndarray, arraysToKeep)) { | ||
continue; | ||
} | ||
ndarray.dispose(); | ||
} | ||
this.ndarrayScopes.pop(); | ||
this.activeScope = this.ndarrayScopes.length === 0 ? | ||
null : | ||
this.ndarrayScopes[this.ndarrayScopes.length - 1]; | ||
resultArrays.forEach(function (val) { | ||
if (!_this.isNDArrayDataInList(val, _this.activeScopeNDArraysToKeep)) { | ||
_this.track(val); | ||
} | ||
}); | ||
this.ndarraysToKeep.pop(); | ||
this.activeScopeNDArraysToKeep = this.ndarraysToKeep.length === 0 ? | ||
null : | ||
this.ndarraysToKeep[this.ndarraysToKeep.length - 1]; | ||
var gradientsMode = false; | ||
this.backendEngine.endScope(result, gradientsMode); | ||
}; | ||
NDArrayMath.prototype.isNDArrayDataInList = function (ndarray, ndarrayList) { | ||
for (var i = 0; i < ndarrayList.length; i++) { | ||
if (ndarrayList[i].id === ndarray.id) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
}; | ||
NDArrayMath.prototype.keep = function (result) { | ||
if (this.activeScope == null) { | ||
if (this.safeMode) { | ||
throw new Error('You are using math in safe mode. Enclose all ' + | ||
'math.method() calls inside a scope: ' + | ||
'math.scope(() => {math.method();...}) to avoid memory ' + | ||
'leaks.'); | ||
} | ||
return result; | ||
} | ||
this.activeScopeNDArraysToKeep.push(result); | ||
return result; | ||
return this.backendEngine.keep(result); | ||
}; | ||
NDArrayMath.prototype.track = function (result) { | ||
if (this.activeScope == null) { | ||
if (this.safeMode) { | ||
throw new Error('You are using math in safe mode. Enclose all ' + | ||
'math.method() calls inside a scope: ' + | ||
'math.scope(() => {math.method();...}) to avoid memory ' + | ||
'leaks.'); | ||
} | ||
return result; | ||
} | ||
this.activeScope.push(result); | ||
return result; | ||
return this.backendEngine.track(result); | ||
}; | ||
@@ -181,2 +102,6 @@ NDArrayMath.prototype.dispose = function () { | ||
return this.backendEngine.executeKernel('MatMul', { inputs: { a: a, b: b }, args: { aOrientation: aOrientation, bOrientation: bOrientation } }, function (dy, y) { | ||
if (aOrientation === matmul_1.MatrixOrientation.TRANSPOSED || | ||
bOrientation === matmul_1.MatrixOrientation.TRANSPOSED) { | ||
throw new Error("Backprop for transposed MatMul not yet implemented."); | ||
} | ||
return { | ||
@@ -481,9 +406,35 @@ a: function () { return _this.matMul(dy, b, matmul_1.MatrixOrientation.REGULAR, matmul_1.MatrixOrientation.TRANSPOSED); }, | ||
NDArrayMath.prototype.subtract = function (a, b) { | ||
var _this = this; | ||
broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); | ||
return this.backendEngine.executeKernel('Sub', { inputs: { a: a, b: b } }); | ||
return this.backendEngine.executeKernel('Sub', { inputs: { a: a, b: b } }, function (dy, y) { | ||
if (!util.arraysEqual(a.shape, b.shape)) { | ||
throw new Error("Backprop through broadcasted subtract not " + | ||
"yet supported."); | ||
} | ||
return { | ||
a: function () { return ndarray_1.NDArray.onesLike(a); }, | ||
b: function () { return _this.scope(function () { return _this.neg(ndarray_1.NDArray.onesLike(b)); }); } | ||
}; | ||
}); | ||
}; | ||
NDArrayMath.prototype.pow = function (a, b) { | ||
var _this = this; | ||
util.assert(b.dtype === 'int32', 'only supports int32 data type for the exponent parameter.'); | ||
broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); | ||
return this.backendEngine.executeKernel('Pow', { inputs: { a: a, b: b } }); | ||
var gradient = function (dy, y) { | ||
if (!util.arraysEqual(a.shape, b.shape)) { | ||
throw new Error("Gradient of pow not yet supported for broadcasted shapes."); | ||
} | ||
var derA = function () { | ||
return _this.scope(function () { | ||
return _this.multiply(dy, _this.multiply(b, _this.pow(a, _this.subtract(b, ndarray_1.Scalar.new(1, 'int32'))))); | ||
}); | ||
}; | ||
var derB = function () { | ||
throw new Error("Backprop through exponent of math.pow not " + | ||
"implemented yet."); | ||
}; | ||
return { a: derA, b: derB }; | ||
}; | ||
return this.backendEngine.executeKernel('Pow', { inputs: { a: a, b: b } }, gradient); | ||
}; | ||
@@ -502,4 +453,10 @@ NDArrayMath.prototype.powStrict = function (a, b) { | ||
NDArrayMath.prototype.multiply = function (a, b) { | ||
var _this = this; | ||
broadcast_util.assertAndGetBroadcastShape(a.shape, b.shape); | ||
return this.backendEngine.executeKernel('Mul', { inputs: { a: a, b: b } }); | ||
return this.backendEngine.executeKernel('Mul', { inputs: { a: a, b: b } }, function (dy, y) { | ||
if (!util.arraysEqual(a.shape, b.shape)) { | ||
throw new Error("Backprop through broadcasted multiply not supported yet."); | ||
} | ||
return { a: function () { return _this.clone(b); }, b: function () { return _this.clone(a); } }; | ||
}); | ||
}; | ||
@@ -547,3 +504,8 @@ NDArrayMath.prototype.elementWiseMul = function (a, b) { | ||
NDArrayMath.prototype.square = function (x) { | ||
return this.backendEngine.executeKernel('Square', { inputs: { x: x } }); | ||
var _this = this; | ||
return this.backendEngine.executeKernel('Square', { inputs: { x: x } }, function (dy, y) { | ||
return { | ||
x: function () { return _this.multiply(dy, _this.multiply(x, ndarray_1.Scalar.new(2))); } | ||
}; | ||
}); | ||
}; | ||
@@ -1057,31 +1019,34 @@ NDArrayMath.prototype.abs = function (x) { | ||
}; | ||
NDArrayMath.prototype.gradientWrt = function (y, x) { | ||
var xIsArray = x instanceof ndarray_1.NDArray; | ||
var xs = []; | ||
var xKeys; | ||
if (xIsArray) { | ||
xs.push(x); | ||
NDArrayMath.prototype.gradients = function (f, x) { | ||
var keys = x instanceof ndarray_1.NDArray ? null : Object.keys(x); | ||
var xs = util.flattenNameArrayMap(x, keys); | ||
var returnValue = false; | ||
var gradients = this.backendEngine.gradients(f, xs, returnValue); | ||
if (x instanceof ndarray_1.NDArray) { | ||
return gradients[0]; | ||
} | ||
else { | ||
var xMap = x; | ||
xKeys = Object.keys(xMap); | ||
for (var i = 0; i < xKeys.length; i++) { | ||
xs.push(xMap[xKeys[i]]); | ||
} | ||
return util.unflattenToNameArrayMap(keys, gradients); | ||
} | ||
var gradients = this.backendEngine.gradientWrt(y, xs); | ||
if (xIsArray) { | ||
return gradients[0]; | ||
}; | ||
NDArrayMath.prototype.valueAndGradients = function (f, x) { | ||
var keys = x instanceof ndarray_1.NDArray ? null : Object.keys(x); | ||
var xs = util.flattenNameArrayMap(x, keys); | ||
var returnValue = true; | ||
var valueAndGradients = this.backendEngine.gradients(f, xs, returnValue); | ||
var gradients; | ||
if (x instanceof ndarray_1.NDArray) { | ||
gradients = valueAndGradients.gradients[0]; | ||
} | ||
else { | ||
var result = {}; | ||
for (var i = 0; i < xKeys.length; i++) { | ||
result[xKeys[i]] = gradients[i]; | ||
} | ||
return result; | ||
gradients = | ||
util.unflattenToNameArrayMap(keys, valueAndGradients.gradients); | ||
} | ||
return { value: valueAndGradients.value, gradients: gradients }; | ||
}; | ||
NDArrayMath.prototype.disposeData = function (id) { | ||
this.backend.disposeData(id); | ||
this.numArrays--; | ||
if (this.registeredArrays.has(id)) { | ||
this.registeredArrays.delete(id); | ||
this.backend.disposeData(id); | ||
} | ||
}; | ||
@@ -1088,0 +1053,0 @@ return NDArrayMath; |
@@ -83,2 +83,4 @@ "use strict"; | ||
this.math.register(this); | ||
} | ||
if (values != null) { | ||
this.math.write(this.id, values, this.dtype, this.shape); | ||
@@ -103,3 +105,3 @@ } | ||
var newValues = copyTypedArray(another.getValues(), another.dtype); | ||
return NDArray.make(another.shape, { values: newValues }, another.dtype); | ||
return NDArray.make(another.shape, { values: newValues }, another.dtype, another.math); | ||
}; | ||
@@ -129,4 +131,4 @@ NDArray.make = function (shape, data, dtype, math) { | ||
var shape = [pixels.height, pixels.width, numChannels]; | ||
var res = NDArray.make(shape, ndarrayData, 'int32'); | ||
math = math || environment_1.ENV.math; | ||
var res = NDArray.make(shape, ndarrayData, 'int32', math); | ||
math.writePixels(res.id, pixels, numChannels); | ||
@@ -143,3 +145,3 @@ return res; | ||
util.assert(this.size === util.sizeFromShape(newShape), 'new shape and old shape must have the same number of elements.'); | ||
return NDArray.make(newShape, data, this.dtype); | ||
return NDArray.make(newShape, data, this.dtype, this.math); | ||
}; | ||
@@ -181,3 +183,3 @@ NDArray.prototype.flatten = function () { | ||
var newVals = toTypedArray(vals, dtype); | ||
return NDArray.make(this.shape, { values: newVals }, dtype); | ||
return NDArray.make(this.shape, { values: newVals }, dtype, this.math); | ||
}; | ||
@@ -223,3 +225,2 @@ Object.defineProperty(NDArray.prototype, "rank", { | ||
vals[index] = value; | ||
this.math.disposeData(this.id); | ||
this.math.write(this.id, vals, this.dtype, this.shape); | ||
@@ -267,3 +268,2 @@ }; | ||
vals.fill(value); | ||
this.math.disposeData(this.id); | ||
this.math.write(this.id, vals, this.dtype, this.shape); | ||
@@ -270,0 +270,0 @@ }; |
import { Features } from './environment'; | ||
import { NDArrayMath } from './math/math'; | ||
import { NDArray } from './math/ndarray'; | ||
import { DType, TypedArray } from './util'; | ||
@@ -11,3 +12,3 @@ export declare const TEST_EPSILON = 0.01; | ||
export declare function expectArrayInMeanStdRange(actual: TypedArray | number[], expectedMean: number, expectedStdDev: number, epsilon?: number): void; | ||
export declare function expectArraysClose(actual: TypedArray | number[], expected: TypedArray | number[], epsilon?: number): void; | ||
export declare function expectArraysClose(actual: NDArray | TypedArray | number[], expected: NDArray | TypedArray | number[], epsilon?: number): void; | ||
export declare function expectNumbersClose(a: number, e: number, epsilon?: number): void; | ||
@@ -14,0 +15,0 @@ export declare function expectValuesInRange(actual: TypedArray | number[], low: number, high: number): void; |
@@ -7,2 +7,3 @@ "use strict"; | ||
var math_1 = require("./math/math"); | ||
var ndarray_1 = require("./math/ndarray"); | ||
var util = require("./util"); | ||
@@ -73,14 +74,41 @@ exports.TEST_EPSILON = 1e-2; | ||
if (epsilon === void 0) { epsilon = exports.TEST_EPSILON; } | ||
var aType = actual.constructor.name; | ||
var bType = expected.constructor.name; | ||
if (aType !== bType) { | ||
throw new Error("Arrays are of different type " + aType + " vs " + bType); | ||
if (!(actual instanceof ndarray_1.NDArray) && !(expected instanceof ndarray_1.NDArray)) { | ||
var aType = actual.constructor.name; | ||
var bType = expected.constructor.name; | ||
if (aType !== bType) { | ||
throw new Error("Arrays are of different type actual: " + aType + " " + | ||
("vs expected: " + bType)); | ||
} | ||
} | ||
if (actual.length !== expected.length) { | ||
throw new Error("Matrices have different lengths (" + actual.length + " vs " + | ||
(expected.length + ").")); | ||
else if (actual instanceof ndarray_1.NDArray && expected instanceof ndarray_1.NDArray) { | ||
if (actual.dtype !== expected.dtype) { | ||
throw new Error("Arrays are of different type actual: " + actual.dtype + " " + | ||
("vs expected: " + expected.dtype + ".")); | ||
} | ||
if (!util.arraysEqual(actual.shape, expected.shape)) { | ||
throw new Error("Arrays are of different shape actual: " + actual.shape + " " + | ||
("vs expected: " + expected.shape + ".")); | ||
} | ||
} | ||
for (var i = 0; i < expected.length; ++i) { | ||
var a = actual[i]; | ||
var e = expected[i]; | ||
var actualValues; | ||
var expectedValues; | ||
if (actual instanceof ndarray_1.NDArray) { | ||
actualValues = actual.dataSync(); | ||
} | ||
else { | ||
actualValues = actual; | ||
} | ||
if (expected instanceof ndarray_1.NDArray) { | ||
expectedValues = expected.dataSync(); | ||
} | ||
else { | ||
expectedValues = expected; | ||
} | ||
if (actualValues.length !== expectedValues.length) { | ||
throw new Error("Arrays have different lengths actual: " + actualValues.length + " vs " + | ||
("expected: " + expectedValues.length + ".")); | ||
} | ||
for (var i = 0; i < expectedValues.length; ++i) { | ||
var a = actualValues[i]; | ||
var e = expectedValues[i]; | ||
if (!areClose(a, e, epsilon)) { | ||
@@ -87,0 +115,0 @@ var actualStr = "actual[" + i + "] === " + a; |
@@ -1,2 +0,2 @@ | ||
import { DataTypes } from './math/ndarray'; | ||
import { DataTypes, NDArray } from './math/ndarray'; | ||
export declare type TypedArray = Float32Array | Int32Array | Uint8Array; | ||
@@ -6,2 +6,5 @@ export declare type FlatVector = boolean[] | number[] | TypedArray; | ||
export declare type ArrayData = TypedArray | RegularArray<number> | RegularArray<boolean>; | ||
export declare type NamedArrayMap = { | ||
[name: string]: NDArray; | ||
}; | ||
export declare function shuffle(array: any[] | Uint32Array | Int32Array | Float32Array): void; | ||
@@ -39,1 +42,5 @@ export declare function clamp(min: number, x: number, max: number): number; | ||
export declare function getTypedArrayFromDType<D extends keyof DataTypes>(dtype: D, size: number): DataTypes[D]; | ||
export declare function isNDArrayInList(ndarray: NDArray, ndarrayList: NDArray[]): boolean; | ||
export declare function checkForNaN(vals: TypedArray, dtype: keyof DataTypes, name: string): void; | ||
export declare function flattenNameArrayMap(nameArrayMap: NDArray | NamedArrayMap, keys?: string[]): NDArray[]; | ||
export declare function unflattenToNameArrayMap(keys: string[], flatArrays: NDArray[]): NamedArrayMap; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var ndarray_1 = require("./math/ndarray"); | ||
function shuffle(array) { | ||
@@ -273,1 +274,43 @@ var counter = array.length; | ||
exports.getTypedArrayFromDType = getTypedArrayFromDType; | ||
function isNDArrayInList(ndarray, ndarrayList) { | ||
for (var i = 0; i < ndarrayList.length; i++) { | ||
if (ndarrayList[i].id === ndarray.id) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
exports.isNDArrayInList = isNDArrayInList; | ||
function checkForNaN(vals, dtype, name) { | ||
for (var i = 0; i < vals.length; i++) { | ||
if (isValNaN(vals[i], dtype)) { | ||
throw Error("The result of the last math." + name + " has NaNs."); | ||
} | ||
} | ||
} | ||
exports.checkForNaN = checkForNaN; | ||
function flattenNameArrayMap(nameArrayMap, keys) { | ||
var xs = []; | ||
if (nameArrayMap instanceof ndarray_1.NDArray) { | ||
xs.push(nameArrayMap); | ||
} | ||
else { | ||
var xMap = nameArrayMap; | ||
for (var i = 0; i < keys.length; i++) { | ||
xs.push(xMap[keys[i]]); | ||
} | ||
} | ||
return xs; | ||
} | ||
exports.flattenNameArrayMap = flattenNameArrayMap; | ||
function unflattenToNameArrayMap(keys, flatArrays) { | ||
if (keys.length !== flatArrays.length) { | ||
throw new Error("Cannot unflatten NDArray[], keys and arrays are not of same length."); | ||
} | ||
var result = {}; | ||
for (var i = 0; i < keys.length; i++) { | ||
result[keys[i]] = flatArrays[i]; | ||
} | ||
return result; | ||
} | ||
exports.unflattenToNameArrayMap = unflattenToNameArrayMap; |
@@ -1,2 +0,2 @@ | ||
declare const version = "0.3.15"; | ||
declare const version = "0.3.16"; | ||
export { version }; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
var version = '0.3.15'; | ||
var version = '0.3.16'; | ||
exports.version = version; |
{ | ||
"name": "deeplearn", | ||
"version": "0.3.15", | ||
"version": "0.3.16", | ||
"description": "Hardware-accelerated JavaScript library for machine intelligence", | ||
@@ -5,0 +5,0 @@ "private": false, |
@@ -28,5 +28,5 @@ <a id="travis-badge" href="https://travis-ci.org/PAIR-code/deeplearnjs" alt="Build Status"> | ||
```ts | ||
import {Array1D, NDArrayMathGPU, Scalar} from 'deeplearn'; | ||
import {Array1D, ENV, Scalar} from 'deeplearn'; | ||
const math = new NDArrayMathGPU(); | ||
const math = ENV.math; | ||
const a = Array1D.new([1, 2, 3]); | ||
@@ -67,3 +67,3 @@ const b = Scalar.new(2); | ||
```js | ||
var math = new dl.NDArrayMathGPU(); | ||
var math = dl.ENV.math; | ||
var a = dl.Array1D.new([1, 2, 3]); | ||
@@ -70,0 +70,0 @@ var b = dl.Scalar.new(2); |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
1765730
28622
240