@magenta/music-vae
Advanced tools
Comparing version 1.0.4 to 1.0.5
@@ -6,3 +6,5 @@ import * as dl from 'deeplearn'; | ||
type: string; | ||
args: DrumsConverterArgs | MelodyConverterArgs; | ||
args: { | ||
[argName: string]: any; | ||
}; | ||
} | ||
@@ -12,14 +14,7 @@ export declare function converterFromSpec(spec: ConverterSpec): MelodyConverter | DrumsConverter; | ||
abstract numSteps: number; | ||
abstract numSegments: number; | ||
abstract toTensor(noteSequence: INoteSequence): dl.Tensor2D; | ||
abstract toNoteSequence(tensor: dl.Tensor2D): INoteSequence; | ||
} | ||
export interface DrumsConverterArgs { | ||
numSteps: number; | ||
numSegments?: number; | ||
pitchClasses?: number[][]; | ||
} | ||
export declare class DrumsConverter extends DataConverter { | ||
numSteps: number; | ||
numSegments: number; | ||
pitchClasses: number[][]; | ||
@@ -29,3 +24,3 @@ pitchToClass: { | ||
}; | ||
constructor(args: DrumsConverterArgs); | ||
constructor(numSteps: number, pitchClasses?: number[][]); | ||
toTensor(noteSequence: INoteSequence): dl.Tensor<dl.Rank.R2>; | ||
@@ -37,11 +32,4 @@ toNoteSequence(oh: dl.Tensor2D): NoteSequence; | ||
} | ||
export interface MelodyConverterArgs { | ||
numSteps: number; | ||
minPitch: number; | ||
maxPitch: number; | ||
numSegments?: number; | ||
} | ||
export declare class MelodyConverter extends DataConverter { | ||
numSteps: number; | ||
numSegments: number; | ||
minPitch: number; | ||
@@ -52,5 +40,5 @@ maxPitch: number; | ||
FIRST_PITCH: number; | ||
constructor(args: MelodyConverterArgs); | ||
constructor(numSteps: number, minPitch: number, maxPitch: number); | ||
toTensor(noteSequence: INoteSequence): dl.Tensor<dl.Rank.R2>; | ||
toNoteSequence(oh: dl.Tensor2D): NoteSequence; | ||
} |
@@ -28,9 +28,9 @@ "use strict"; | ||
if (spec.type === 'MelodyConverter') { | ||
return new MelodyConverter(spec.args); | ||
return new MelodyConverter(spec.args.numSteps, spec.args.minPitch, spec.args.maxPitch); | ||
} | ||
else if (spec.type === 'DrumsConverter') { | ||
return new DrumsConverter(spec.args); | ||
return new DrumsConverter(spec.args.numSteps, spec.args.pitchClasses); | ||
} | ||
else if (spec.type === 'DrumRollConverter') { | ||
return new DrumRollConverter(spec.args); | ||
return new DrumRollConverter(spec.args.numSteps, spec.args.pitchClasses); | ||
} | ||
@@ -50,14 +50,12 @@ else { | ||
__extends(DrumsConverter, _super); | ||
function DrumsConverter(args) { | ||
function DrumsConverter(numSteps, pitchClasses) { | ||
var _this = _super.call(this) || this; | ||
_this.pitchClasses = (args.pitchClasses) ? | ||
args.pitchClasses : DEFAULT_DRUM_PITCH_CLASSES; | ||
_this.numSteps = args.numSteps; | ||
_this.numSegments = args.numSegments; | ||
pitchClasses = (pitchClasses) ? pitchClasses : DEFAULT_DRUM_PITCH_CLASSES; | ||
_this.numSteps = numSteps; | ||
_this.pitchClasses = pitchClasses; | ||
_this.pitchToClass = {}; | ||
var _loop_1 = function (c) { | ||
this_1.pitchClasses[c].forEach(function (p) { _this.pitchToClass[p] = c; }); | ||
pitchClasses[c].forEach(function (p) { _this.pitchToClass[p] = c; }); | ||
}; | ||
var this_1 = this; | ||
for (var c = 0; c < _this.pitchClasses.length; ++c) { | ||
for (var c = 0; c < pitchClasses.length; ++c) { | ||
_loop_1(c); | ||
@@ -128,11 +126,10 @@ } | ||
__extends(MelodyConverter, _super); | ||
function MelodyConverter(args) { | ||
function MelodyConverter(numSteps, minPitch, maxPitch) { | ||
var _this = _super.call(this) || this; | ||
_this.NOTE_OFF = 1; | ||
_this.FIRST_PITCH = 2; | ||
_this.numSteps = args.numSteps; | ||
_this.numSegments = args.numSegments; | ||
_this.minPitch = args.minPitch; | ||
_this.maxPitch = args.maxPitch; | ||
_this.depth = args.maxPitch - args.minPitch + 3; | ||
_this.numSteps = numSteps; | ||
_this.minPitch = minPitch; | ||
_this.maxPitch = maxPitch; | ||
_this.depth = maxPitch - minPitch + 3; | ||
return _this; | ||
@@ -139,0 +136,0 @@ } |
@@ -8,11 +8,2 @@ import * as dl from 'deeplearn'; | ||
} | ||
declare abstract class Encoder { | ||
abstract zDims: number; | ||
abstract encode(sequence: dl.Tensor3D): dl.Tensor2D; | ||
} | ||
declare abstract class Decoder { | ||
abstract outputDims: number; | ||
abstract zDims: number; | ||
abstract decode(z: dl.Tensor2D, length: number, initialInput?: dl.Tensor2D, temperature?: number): dl.Tensor3D; | ||
} | ||
declare class Nade { | ||
@@ -26,11 +17,31 @@ encWeights: dl.Tensor2D; | ||
} | ||
declare class Encoder { | ||
lstmFwVars: LayerVars; | ||
lstmBwVars: LayerVars; | ||
muVars: LayerVars; | ||
zDims: number; | ||
constructor(lstmFwVars: LayerVars, lstmBwVars: LayerVars, muVars: LayerVars); | ||
encode(sequence: dl.Tensor3D): dl.Tensor2D; | ||
private runLstm(inputs, lstmVars, reverse); | ||
} | ||
declare class Decoder { | ||
lstmCellVars: LayerVars[]; | ||
zToInitStateVars: LayerVars; | ||
outputProjectVars: LayerVars; | ||
zDims: number; | ||
outputDims: number; | ||
nade: Nade; | ||
constructor(lstmCellVars: LayerVars[], zToInitStateVars: LayerVars, outputProjectVars: LayerVars, nade?: Nade); | ||
decode(z: dl.Tensor2D, length: number, temperature?: number): dl.Tensor<dl.Rank.R3>; | ||
} | ||
declare class MusicVAE { | ||
private checkpointURL; | ||
private dataConverter; | ||
private encoder; | ||
private decoder; | ||
private rawVars; | ||
checkpointURL: string; | ||
dataConverter: data.DataConverter; | ||
encoder: Encoder; | ||
decoder: Decoder; | ||
rawVars: { | ||
[varName: string]: dl.Tensor; | ||
}; | ||
constructor(checkpointURL: string, dataConverter?: data.DataConverter); | ||
dispose(): void; | ||
private getLstmLayers(cellFormat, vars); | ||
initialize(): Promise<this>; | ||
@@ -37,0 +48,0 @@ isInitialized(): boolean; |
338
es5/model.js
"use strict"; | ||
var __extends = (this && this.__extends) || (function () { | ||
var extendStatics = Object.setPrototypeOf || | ||
({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || | ||
function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; }; | ||
return function (d, b) { | ||
extendStatics(d, b); | ||
function __() { this.constructor = d; } | ||
d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); | ||
}; | ||
})(); | ||
var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { | ||
@@ -51,3 +41,3 @@ return new (P || (P = Promise))(function (resolve, reject) { | ||
var checkpoint_loader_1 = require("./checkpoint_loader"); | ||
var forgetBias = dl.scalar(1.0); | ||
var DECODER_CELL_FORMAT = "decoder/multi_rnn_cell/cell_%d/lstm_cell/"; | ||
var LayerVars = (function () { | ||
@@ -64,37 +54,55 @@ function LayerVars(kernel, bias) { | ||
} | ||
var Encoder = (function () { | ||
function Encoder() { | ||
var Nade = (function () { | ||
function Nade(encWeights, decWeightsT) { | ||
this.numDims = encWeights.shape[0]; | ||
this.numHidden = encWeights.shape[2]; | ||
this.encWeights = encWeights.as2D(this.numDims, this.numHidden); | ||
this.decWeightsT = decWeightsT.as2D(this.numDims, this.numHidden); | ||
} | ||
return Encoder; | ||
Nade.prototype.sample = function (encBias, decBias) { | ||
var _this = this; | ||
var batchSize = encBias.shape[0]; | ||
return dl.tidy(function () { | ||
var samples = []; | ||
var a = encBias.clone(); | ||
for (var i = 0; i < _this.numDims; i++) { | ||
var h = dl.sigmoid(a); | ||
var encWeightsI = _this.encWeights.slice([i, 0], [1, _this.numHidden]).as1D(); | ||
var decWeightsTI = _this.decWeightsT.slice([i, 0], [1, _this.numHidden]); | ||
var decBiasI = decBias.slice([0, i], [batchSize, 1]); | ||
var condLogitsI = decBiasI.add(dl.matMul(h, decWeightsTI, false, true)); | ||
var condProbsI = condLogitsI.sigmoid(); | ||
var samplesI = condProbsI.greaterEqual(dl.scalar(0.5)).toFloat().as1D(); | ||
if (i < _this.numDims - 1) { | ||
a = a.add(dl.outerProduct(samplesI.toFloat(), encWeightsI)); | ||
} | ||
samples.push(samplesI); | ||
} | ||
return dl.stack(samples, 1); | ||
}); | ||
}; | ||
return Nade; | ||
}()); | ||
exports.Encoder = Encoder; | ||
var BidirectonalLstmEncoder = (function (_super) { | ||
__extends(BidirectonalLstmEncoder, _super); | ||
function BidirectonalLstmEncoder(lstmFwVars, lstmBwVars, muVars) { | ||
var _this = _super.call(this) || this; | ||
_this.lstmFwVars = lstmFwVars; | ||
_this.lstmBwVars = lstmBwVars; | ||
_this.muVars = muVars; | ||
_this.zDims = muVars ? _this.muVars.bias.shape[0] : null; | ||
return _this; | ||
exports.Nade = Nade; | ||
var Encoder = (function () { | ||
function Encoder(lstmFwVars, lstmBwVars, muVars) { | ||
this.lstmFwVars = lstmFwVars; | ||
this.lstmBwVars = lstmBwVars; | ||
this.muVars = muVars; | ||
this.zDims = this.muVars.bias.shape[0]; | ||
} | ||
BidirectonalLstmEncoder.prototype.encode = function (sequence) { | ||
Encoder.prototype.encode = function (sequence) { | ||
var _this = this; | ||
return dl.tidy(function () { | ||
var fwState = _this.singleDirection(sequence, true); | ||
var bwState = _this.singleDirection(sequence, false); | ||
var fwState = _this.runLstm(sequence, _this.lstmFwVars, false); | ||
var bwState = _this.runLstm(sequence, _this.lstmBwVars, true); | ||
var finalState = dl.concat2d([fwState[1], bwState[1]], 1); | ||
if (_this.muVars) { | ||
return dense(_this.muVars, finalState); | ||
} | ||
else { | ||
return finalState; | ||
} | ||
var mu = dense(_this.muVars, finalState); | ||
return mu; | ||
}); | ||
}; | ||
BidirectonalLstmEncoder.prototype.singleDirection = function (inputs, fw) { | ||
Encoder.prototype.runLstm = function (inputs, lstmVars, reverse) { | ||
var batchSize = inputs.shape[0]; | ||
var length = inputs.shape[1]; | ||
var outputSize = inputs.shape[2]; | ||
var lstmVars = fw ? this.lstmFwVars : this.lstmBwVars; | ||
var state = [ | ||
@@ -104,2 +112,3 @@ dl.zeros([batchSize, lstmVars.bias.shape[0] / 4]), | ||
]; | ||
var forgetBias = dl.scalar(1.0); | ||
var lstm = function (data, state) { | ||
@@ -109,3 +118,3 @@ return dl.basicLSTMCell(forgetBias, lstmVars.kernel, lstmVars.bias, data, state[0], state[1]); | ||
for (var i = 0; i < length; i++) { | ||
var index = fw ? i : length - 1 - i; | ||
var index = reverse ? length - 1 - i : i; | ||
state = lstm(inputs.slice([0, index, 0], [batchSize, 1, outputSize]).as2D(batchSize, outputSize), state); | ||
@@ -115,88 +124,45 @@ } | ||
}; | ||
return BidirectonalLstmEncoder; | ||
}(Encoder)); | ||
var HierarhicalEncoder = (function (_super) { | ||
__extends(HierarhicalEncoder, _super); | ||
function HierarhicalEncoder(baseEncoders, numSteps, muVars) { | ||
var _this = _super.call(this) || this; | ||
_this.baseEncoders = baseEncoders; | ||
_this.numSteps = numSteps; | ||
_this.muVars = muVars; | ||
_this.zDims = _this.muVars.bias.shape[0]; | ||
return _this; | ||
} | ||
HierarhicalEncoder.prototype.encode = function (sequence) { | ||
var _this = this; | ||
return dl.tidy(function () { | ||
var batchSize = sequence.shape[0]; | ||
var inputs = sequence; | ||
for (var level = 0; level < _this.baseEncoders.length; ++level) { | ||
var levelSteps = _this.numSteps[level]; | ||
var stepSize = inputs.shape[1] / levelSteps; | ||
var depth = inputs.shape[2]; | ||
var embeddings = []; | ||
for (var step = 0; step < levelSteps; ++step) { | ||
embeddings.push(_this.baseEncoders[level].encode(inputs.slice([0, step * stepSize, 0], [batchSize, stepSize, depth]))); | ||
} | ||
inputs = (embeddings.length > 1) ? | ||
dl.stack(embeddings, 1) : | ||
embeddings[0].expandDims(1); | ||
} | ||
return dense(_this.muVars, inputs.squeeze([1])); | ||
}); | ||
}; | ||
return HierarhicalEncoder; | ||
}(Encoder)); | ||
function initLstmCells(z, lstmCellVars, zToInitStateVars) { | ||
var batchSize = z.shape[0]; | ||
var lstmCells = []; | ||
var c = []; | ||
var h = []; | ||
var initialStates = dense(zToInitStateVars, z).tanh(); | ||
var stateOffset = 0; | ||
var _loop_1 = function (i) { | ||
var lv = lstmCellVars[i]; | ||
var stateWidth = lv.bias.shape[0] / 4; | ||
lstmCells.push(function (data, c, h) { | ||
return dl.basicLSTMCell(forgetBias, lv.kernel, lv.bias, data, c, h); | ||
}); | ||
c.push(initialStates.slice([0, stateOffset], [batchSize, stateWidth])); | ||
stateOffset += stateWidth; | ||
h.push(initialStates.slice([0, stateOffset], [batchSize, stateWidth])); | ||
stateOffset += stateWidth; | ||
}; | ||
for (var i = 0; i < lstmCellVars.length; ++i) { | ||
_loop_1(i); | ||
} | ||
return { 'cell': lstmCells, 'c': c, 'h': h }; | ||
} | ||
return Encoder; | ||
}()); | ||
exports.Encoder = Encoder; | ||
var Decoder = (function () { | ||
function Decoder() { | ||
function Decoder(lstmCellVars, zToInitStateVars, outputProjectVars, nade) { | ||
this.lstmCellVars = lstmCellVars; | ||
this.zToInitStateVars = zToInitStateVars; | ||
this.outputProjectVars = outputProjectVars; | ||
this.zDims = this.zToInitStateVars.kernel.shape[0]; | ||
this.outputDims = (nade) ? nade.numDims : outputProjectVars.bias.shape[0]; | ||
this.nade = nade; | ||
} | ||
return Decoder; | ||
}()); | ||
exports.Decoder = Decoder; | ||
var BaseDecoder = (function (_super) { | ||
__extends(BaseDecoder, _super); | ||
function BaseDecoder(lstmCellVars, zToInitStateVars, outputProjectVars, nade) { | ||
var _this = _super.call(this) || this; | ||
_this.lstmCellVars = lstmCellVars; | ||
_this.zToInitStateVars = zToInitStateVars; | ||
_this.outputProjectVars = outputProjectVars; | ||
_this.zDims = _this.zToInitStateVars.kernel.shape[0]; | ||
_this.outputDims = (nade) ? nade.numDims : outputProjectVars.bias.shape[0]; | ||
_this.nade = nade; | ||
return _this; | ||
} | ||
BaseDecoder.prototype.decode = function (z, length, initialInput, temperature) { | ||
Decoder.prototype.decode = function (z, length, temperature) { | ||
var _this = this; | ||
var batchSize = z.shape[0]; | ||
return dl.tidy(function () { | ||
var lstmCell = initLstmCells(z, _this.lstmCellVars, _this.zToInitStateVars); | ||
var lstmCells = []; | ||
var c = []; | ||
var h = []; | ||
var initialStates = dense(_this.zToInitStateVars, z).tanh(); | ||
var stateOffset = 0; | ||
var _loop_1 = function (i) { | ||
var lv = _this.lstmCellVars[i]; | ||
var stateWidth = lv.bias.shape[0] / 4; | ||
var forgetBias = dl.scalar(1.0); | ||
lstmCells.push(function (data, c, h) { | ||
return dl.basicLSTMCell(forgetBias, lv.kernel, lv.bias, data, c, h); | ||
}); | ||
c.push(initialStates.slice([0, stateOffset], [batchSize, stateWidth])); | ||
stateOffset += stateWidth; | ||
h.push(initialStates.slice([0, stateOffset], [batchSize, stateWidth])); | ||
stateOffset += stateWidth; | ||
}; | ||
for (var i = 0; i < _this.lstmCellVars.length; ++i) { | ||
_loop_1(i); | ||
} | ||
var samples = []; | ||
var nextInput = initialInput ? | ||
initialInput : dl.zeros([batchSize, _this.outputDims]); | ||
var nextInput = dl.zeros([batchSize, _this.outputDims]); | ||
for (var i = 0; i < length; ++i) { | ||
_a = dl.multiRNNCell(lstmCell.cell, dl.concat2d([nextInput, z], 1), lstmCell.c, lstmCell.h), lstmCell.c = _a[0], lstmCell.h = _a[1]; | ||
var logits = dense(_this.outputProjectVars, lstmCell.h[lstmCell.h.length - 1]); | ||
var output = dl.multiRNNCell(lstmCells, dl.concat2d([nextInput, z], 1), c, h); | ||
c = output[0]; | ||
h = output[1]; | ||
var logits = dense(_this.outputProjectVars, h[h.length - 1]); | ||
var timeSamples = void 0; | ||
@@ -219,71 +185,7 @@ if (_this.nade == null) { | ||
return dl.stack(samples, 1); | ||
var _a; | ||
}); | ||
}; | ||
return BaseDecoder; | ||
}(Decoder)); | ||
var ConductorDecoder = (function (_super) { | ||
__extends(ConductorDecoder, _super); | ||
function ConductorDecoder(coreDecoder, lstmCellVars, zToInitStateVars, numSteps) { | ||
var _this = _super.call(this) || this; | ||
_this.coreDecoder = coreDecoder; | ||
_this.lstmCellVars = lstmCellVars; | ||
_this.zToInitStateVars = zToInitStateVars; | ||
_this.numSteps = numSteps; | ||
_this.zDims = _this.zToInitStateVars.kernel.shape[0]; | ||
_this.outputDims = _this.coreDecoder.outputDims; | ||
return _this; | ||
} | ||
ConductorDecoder.prototype.decode = function (z, length, initialInput, temperature) { | ||
var _this = this; | ||
var batchSize = z.shape[0]; | ||
return dl.tidy(function () { | ||
var lstmCell = initLstmCells(z, _this.lstmCellVars, _this.zToInitStateVars); | ||
var samples = []; | ||
var dummyInput = dl.zeros([batchSize, 1]); | ||
for (var i = 0; i < _this.numSteps; ++i) { | ||
_a = dl.multiRNNCell(lstmCell.cell, dummyInput, lstmCell.c, lstmCell.h), lstmCell.c = _a[0], lstmCell.h = _a[1]; | ||
var initialInput_1 = samples.length ? | ||
samples[samples.length - 1].slice([0, -1, 0], [batchSize, 1, _this.outputDims]).as2D(batchSize, -1).toFloat() : | ||
undefined; | ||
samples.push(_this.coreDecoder.decode(lstmCell.h[lstmCell.h.length - 1], length / _this.numSteps, initialInput_1, temperature)); | ||
} | ||
return dl.concat(samples, 1); | ||
var _a; | ||
}); | ||
}; | ||
return ConductorDecoder; | ||
}(Decoder)); | ||
var Nade = (function () { | ||
function Nade(encWeights, decWeightsT) { | ||
this.numDims = encWeights.shape[0]; | ||
this.numHidden = encWeights.shape[2]; | ||
this.encWeights = encWeights.as2D(this.numDims, this.numHidden); | ||
this.decWeightsT = decWeightsT.as2D(this.numDims, this.numHidden); | ||
} | ||
Nade.prototype.sample = function (encBias, decBias) { | ||
var _this = this; | ||
var batchSize = encBias.shape[0]; | ||
return dl.tidy(function () { | ||
var samples = []; | ||
var a = encBias.clone(); | ||
for (var i = 0; i < _this.numDims; i++) { | ||
var h = dl.sigmoid(a); | ||
var encWeightsI = _this.encWeights.slice([i, 0], [1, _this.numHidden]).as1D(); | ||
var decWeightsTI = _this.decWeightsT.slice([i, 0], [1, _this.numHidden]); | ||
var decBiasI = decBias.slice([0, i], [batchSize, 1]); | ||
var condLogitsI = decBiasI.add(dl.matMul(h, decWeightsTI, false, true)); | ||
var condProbsI = condLogitsI.sigmoid(); | ||
var samplesI = condProbsI.greaterEqual(dl.scalar(0.5)).toFloat().as1D(); | ||
if (i < _this.numDims - 1) { | ||
a = a.add(dl.outerProduct(samplesI.toFloat(), encWeightsI)); | ||
} | ||
samples.push(samplesI); | ||
} | ||
return dl.stack(samples, 1); | ||
}); | ||
}; | ||
return Nade; | ||
return Decoder; | ||
}()); | ||
exports.Nade = Nade; | ||
exports.Decoder = Decoder; | ||
var MusicVAE = (function () { | ||
@@ -310,27 +212,8 @@ function MusicVAE(checkpointURL, dataConverter) { | ||
}; | ||
MusicVAE.prototype.getLstmLayers = function (cellFormat, vars) { | ||
var lstmLayers = []; | ||
var l = 0; | ||
while (true) { | ||
var cellPrefix = cellFormat.replace('%d', l.toString()); | ||
if (!(cellPrefix + 'kernel' in vars)) { | ||
break; | ||
} | ||
lstmLayers.push(new LayerVars(vars[cellPrefix + 'kernel'], vars[cellPrefix + 'bias'])); | ||
++l; | ||
} | ||
return lstmLayers; | ||
}; | ||
MusicVAE.prototype.initialize = function () { | ||
return __awaiter(this, void 0, void 0, function () { | ||
var LSTM_CELL_FORMAT, MUTLI_LSTM_CELL_FORMAT, CONDUCTOR_PREFIX, BIDI_LSTM_CELL, ENCODER_FORMAT, HIER_ENCODER_FORMAT, reader, vars, encMu, fwLayers_1, bwLayers_1, baseEncoders, fwLayers, bwLayers, decVarPrefix, decLstmLayers, decZtoInitState, decOutputProjection, nade, decoder, condLstmLayers, condZtoInitState; | ||
var reader, vars, encLstmFw, encLstmBw, encMu, decLstmLayers, l, cellPrefix, decZtoInitState, decOutputProjection, nade; | ||
return __generator(this, function (_a) { | ||
switch (_a.label) { | ||
case 0: | ||
LSTM_CELL_FORMAT = 'cell_%d/lstm_cell/'; | ||
MUTLI_LSTM_CELL_FORMAT = 'multi_rnn_cell/' + LSTM_CELL_FORMAT; | ||
CONDUCTOR_PREFIX = 'decoder/hierarchical_level_0/'; | ||
BIDI_LSTM_CELL = 'cell_%d/bidirectional_rnn/%s/multi_rnn_cell/cell_0/lstm_cell/'; | ||
ENCODER_FORMAT = 'encoder/' + BIDI_LSTM_CELL; | ||
HIER_ENCODER_FORMAT = 'encoder/hierarchical_level_%d/' + BIDI_LSTM_CELL.replace('%d', '0'); | ||
reader = new checkpoint_loader_1.CheckpointLoader(this.checkpointURL); | ||
@@ -340,39 +223,22 @@ return [4, reader.getAllVariables()]; | ||
vars = _a.sent(); | ||
this.rawVars = vars; | ||
encLstmFw = new LayerVars(vars['encoder/cell_0/bidirectional_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel'], vars['encoder/cell_0/bidirectional_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias']); | ||
encLstmBw = new LayerVars(vars['encoder/cell_0/bidirectional_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel'], vars['encoder/cell_0/bidirectional_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias']); | ||
encMu = new LayerVars(vars['encoder/mu/kernel'], vars['encoder/mu/bias']); | ||
if (this.dataConverter.numSegments) { | ||
fwLayers_1 = this.getLstmLayers(HIER_ENCODER_FORMAT.replace('%s', 'fw'), vars); | ||
bwLayers_1 = this.getLstmLayers(HIER_ENCODER_FORMAT.replace('%s', 'bw'), vars); | ||
if (fwLayers_1.length !== bwLayers_1.length || fwLayers_1.length !== 2) { | ||
throw Error('Only 2 hierarchical encoder levels are supported. ' + | ||
'Got ' + fwLayers_1.length + ' forward and ' + | ||
bwLayers_1.length + ' backward.'); | ||
decLstmLayers = []; | ||
l = 0; | ||
while (true) { | ||
cellPrefix = DECODER_CELL_FORMAT.replace('%d', l.toString()); | ||
if (!(cellPrefix + 'kernel' in vars)) { | ||
break; | ||
} | ||
baseEncoders = [0, 1].map(function (l) { return new BidirectonalLstmEncoder(fwLayers_1[l], bwLayers_1[l]); }); | ||
this.encoder = new HierarhicalEncoder(baseEncoders, [this.dataConverter.numSegments, 1], encMu); | ||
decLstmLayers.push(new LayerVars(vars[cellPrefix + 'kernel'], vars[cellPrefix + 'bias'])); | ||
++l; | ||
} | ||
else { | ||
fwLayers = this.getLstmLayers(ENCODER_FORMAT.replace('%s', 'fw'), vars); | ||
bwLayers = this.getLstmLayers(ENCODER_FORMAT.replace('%s', 'bw'), vars); | ||
if (fwLayers.length !== bwLayers.length || fwLayers.length !== 1) { | ||
throw Error('Only single-layer bidirectional encoders are supported. ' + | ||
'Got ' + fwLayers.length + ' forward and ' + | ||
bwLayers.length + ' backward.'); | ||
} | ||
this.encoder = new BidirectonalLstmEncoder(fwLayers[0], bwLayers[0], encMu); | ||
} | ||
decVarPrefix = (this.dataConverter.numSegments) ? | ||
'core_decoder/decoder/' : 'decoder/'; | ||
decLstmLayers = this.getLstmLayers(decVarPrefix + MUTLI_LSTM_CELL_FORMAT, vars); | ||
decZtoInitState = new LayerVars(vars[decVarPrefix + 'z_to_initial_state/kernel'], vars[decVarPrefix + 'z_to_initial_state/bias']); | ||
decOutputProjection = new LayerVars(vars[decVarPrefix + 'output_projection/kernel'], vars[decVarPrefix + 'output_projection/bias']); | ||
nade = ((decVarPrefix + 'nade/w_enc' in vars) ? | ||
new Nade(vars[decVarPrefix + 'nade/w_enc'], vars[decVarPrefix + 'nade/w_dec_t']) : null); | ||
decoder = new BaseDecoder(decLstmLayers, decZtoInitState, decOutputProjection, nade); | ||
if (this.dataConverter.numSegments) { | ||
condLstmLayers = this.getLstmLayers(CONDUCTOR_PREFIX + LSTM_CELL_FORMAT, vars); | ||
condZtoInitState = new LayerVars(vars[CONDUCTOR_PREFIX + 'initial_state/kernel'], vars[CONDUCTOR_PREFIX + 'initial_state/bias']); | ||
decoder = new ConductorDecoder(decoder, condLstmLayers, condZtoInitState, this.dataConverter.numSegments); | ||
} | ||
this.decoder = decoder; | ||
decZtoInitState = new LayerVars(vars['decoder/z_to_initial_state/kernel'], vars['decoder/z_to_initial_state/bias']); | ||
decOutputProjection = new LayerVars(vars['decoder/output_projection/kernel'], vars['decoder/output_projection/bias']); | ||
nade = (('decoder/nade/w_enc' in vars) ? | ||
new Nade(vars['decoder/nade/w_enc'], vars['decoder/nade/w_dec_t']) : null); | ||
this.encoder = new Encoder(encLstmFw, encLstmBw, encMu); | ||
this.decoder = new Decoder(decLstmLayers, decZtoInitState, decOutputProjection, nade); | ||
this.rawVars = vars; | ||
return [2, this]; | ||
@@ -448,3 +314,3 @@ } | ||
var randZs = dl.randomNormal([numSamples, _this.decoder.zDims]); | ||
return _this.decoder.decode(randZs, numSteps, undefined, temperature); | ||
return _this.decoder.decode(randZs, numSteps, temperature); | ||
}); | ||
@@ -451,0 +317,0 @@ }; |
{ | ||
"name": "@magenta/music-vae", | ||
"version": "1.0.4", | ||
"description": "", | ||
"version": "1.0.5", | ||
"description": "A machine learning model for exploring latent spaces of musical scores", | ||
"main": "es5/index.js", | ||
@@ -6,0 +6,0 @@ "types": "es5/index.d.ts", |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Major refactor
Supply chain risk: Package has recently undergone a major refactor. It may be unstable or indicate significant internal changes. Use caution when updating to versions that include significant changes.
Found 1 instance in 1 package
1
-50%1150154
-1.04%23864
-0.58%