Socket
Socket
Sign inDemoInstall

@tensorflow/tfjs-layers

Package Overview
Dependencies
Maintainers
9
Versions
157
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

@tensorflow/tfjs-layers - npm Package Compare versions

Comparing version 4.10.0 to 4.11.0

dist/layers/nlp/modeling/transformer_layer_utils.d.ts

2

dist/base_callbacks.js

@@ -491,2 +491,2 @@ /**

}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"base_callbacks.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/base_callbacks.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,yCAAyC;AAEzC,OAAO,EAAC,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,GAAG,EAAE,SAAS,EAAkB,IAAI,EAAE,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGjG,OAAO,EAAC,UAAU,EAAC,MAAM,UAAU,CAAC;AACpC,OAAO,EAAO,oBAAoB,EAAiB,MAAM,QAAQ,CAAC;AAClE,OAAO,KAAK,aAAa,MAAM,uBAAuB,CAAC;AAEvD,oDAAoD;AACpD,MAAM,CAAN,IAAY,qBAGX;AAHD,WAAY,qBAAqB;IAC/B,qEAAU,CAAA;IACV,uEAAW,CAAA;AACb,CAAC,EAHW,qBAAqB,KAArB,qBAAqB,QAGhC;AAED,mEAAmE;AACnE,MAAM,CAAC,MAAM,sBAAsB,GAAG,GAAG,CAAC;AAQ1C;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,OAAgB,YAAY;IAAlC;QACE,iDAAiD;QACjD,mBAAc,GAAoB,IAAI,CAAC;IAgCzC,CAAC;IA1BC,SAAS,CAAC,MAAc;QACtB,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;IACvB,CAAC;IAED,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAE3D,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAEzD,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAE3D,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAEzD,KAAK,CAAC,YAAY,CAAC,IAAqB,IAAG,CAAC;IAE5C,KAAK,CAAC,UAAU,CAAC,IAAqB,IAAG,CAAC;IAE1C,4EAA4E;IAC5E,8EAA8E;IAC9E,8EAA8E;IAC9E,0EAA0E;IAC1E,8EAA8E;IAC9E,mEAAmE;IACnE,eAAe;IACf,QAAQ,CAAC,KAAgB;QACvB,uEAAuE;IACzE,CAAC;CACF;AAED;;GAEG;AACH,MAAM,OAAO,YAAY;IAIvB,sEAAsE;IACtE,uCAAuC;IACvC,+BAA+B;IAC/B,4CAA4C;IAC5C,0CAA0C;IAE1C;;;;;OAKG;IACH,YAAY,SAA0B,EAAE,WAAW,GAAG,EAAE;QACtD,2EAA2E;QAC3E,UAAU;QACV,IAAI,SAAS,IAAI,IAAI,EAAE;YACrB,SAAS,GAAG,EAAE,CAAC;SAChB;QACD,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,IAAI,CAAC,WAAW,GAAG,WAAW,CAAC;IACjC,CAAC;IAED,MAAM,CAAC,QAAsB;QAC3B,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;IAChC,CAAC;IAED,SAAS,CAAC,MAAc;QACtB,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,QAAQ,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;SAC5B;IACH,CAAC;IAED,QAAQ,CAAC,KAAgB;QACvB,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,QAAQ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;SAC1B;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QACrD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,YAAY,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SAC1C;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QACnD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SACxC;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QACrD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,YAAY,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SAC1C;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QACnD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SACxC;IACH,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,YAAY,CAAC,IAAqB;QACtC,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;SACnC;IACH,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,UAAU,CAAC,IAAqB;QACpC,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,UAAU,CAAC,IAAI,CAAC,CAAC;SACjC;IACH,CAAC;CACF;AAED;;;;GAIG;AACH,MAAM,OAAO,UAAW,SAAQ,YAAY;IAI1C;QACE,KAAK,EAAE,CAAC;IACV,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,KAAa;QACvC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;QACd,IAAI,CAAC,MAAM,GAAG,EAAE,CAAC;IACnB,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAC5D,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,MAAM,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,MAAM,CAAW,CAAC;QACpE,IAAI,CAAC,IAAI,IAAI,SAAS,CAAC;QACvB,KAAK,MAAM,GAAG,IAAI,IAAI,EAAE;YACtB,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;YACxB,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;gBAC7B,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE;oBACpC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;iBACtB;gBACD,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAW,GAAG,KAAK,GAAG,SAAS,CAAC;aACnE;iBAAM;gBACL,IAAI,kBAA0B,CAAC;gBAC/B,IAAI,GAAG,IAAI,IAAI,CAAC,MAAM,EAAE;oBACtB,kBAAkB,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAW,CAAC;iBACjD;qBAAM;oBACL,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;iBACtB;gBACD,MAAM,KAAK,GACP,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC;gBAC/D,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,KAAK,CAAC;gBACzB,IAAI,kBAAkB,IAAI,IAAI,EAAE;oBAC9B,kBAAkB,CAAC,OAAO,EAAE,CAAC;iBAC9B;aACF;SACF;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAC5D,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,KAAK,MAAM,GAAG,IAAI,IAAI,CAAC,MAAM,CAAC,SAAS,CAAa,EAAE;gBACpD,IAAI,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,IAAI,IAAI,EAAE;oBAC5B,SAAS;iBACV;gBACD,IAAI,OAAO,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,KAAK,QAAQ,EAAE;oBACxC,IAAI,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAW,GAAG,IAAI,CAAC,IAAI,CAAC;iBACpD;qBAAM;oBACL,IAAI,CAAC,GAAG,EAAE;wBACR,MAAM,GAAG,GAAW,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,EAAE,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC;wBAC7D,IAAI,CAAC,GAAG,CAAC,GAAG,GAAG,CAAC;wBACf,IAAI,CAAC,MAAM,CAAC,GAAG,CAAY,CAAC,OAAO,EAAE,CAAC;wBACvC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAW,CAAC,CAAC;oBAC5B,CAAC,CAAC,CAAC;iBACJ;aACF;SACF;IACH,CAAC;CACF;AAED;;;;GAIG;AACH,MAAM,OAAO,OAAQ,SAAQ,YAAY;IAI9B,KAAK,CAAC,YAAY,CAAC,IAAqB;QAC/C,IAAI,CAAC,KAAK,GAAG,EAAE,CAAC;QAChB,IAAI,CAAC,OAAO,GAAG,EAAE,CAAC;IACpB,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAC5D,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACvB,KAAK,MAAM,GAAG,IAAI,IAAI,EAAE;YACtB,IAAI,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,IAAI,IAAI,EAAE;gBAC7B,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC;aACxB;YACD,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;SACnC;IACH,CAAC;IAED;;OAEG;IACH,KAAK,CAAC,QAAQ;QACZ,MAAM,QAAQ,GAAuD,EAAE,CAAC;QACxE,MAAM,IAAI,GAAa,EAAE,CAAC;QAC1B,MAAM,OAAO,GAAa,EAAE,CAAC;QAC7B,KAAK,MAAM,GAAG,IAAI,IAAI,CAAC,OAAO,EAAE;YAC9B,MAAM,UAAU,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;YACrC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBAC1C,IAAI,OAAO,UAAU,CAAC,CAAC,CAAC,KAAK,QAAQ,EAAE;oBACrC,MAAM,WAAW,GAAG,UAAU,CAAC,CAAC,CAAW,CAAC;oBAC5C,QAAQ,CAAC,IAAI,CAAC,WAAW,CAAC,IAAI,EAAE,CAAC,CAAC;oBAClC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;oBACf,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;iBACjB;aACF;SACF;QACD,MAAM,MAAM,GAAG,MAAM,OAAO,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAC3C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACtC,MAAM,eAAe,GAAG,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAW,CAAC;YACpE,eAAe,CAAC,OAAO,EAAE,CAAC;YAC1B,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SAClD;IACH,CAAC;CACF;AAeD;;GAEG;AACH,MAAM,OAAO,cAAe,SAAQ,YAAY;IAmB9C,YAAY,IAAwB,EAAE,UAA8B;QAClE,KAAK,EAAE,CAAC;QALF,iBAAY,GAAG,CAAC,CAAC;QAMvB,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC;QAC5B,IAAI,CAAC,aAAa,GAAG,IAAI,CAAC,aAAa,IAAI,SAAS,CAAC;QACrD,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QACvC,IAAI,IAAI,CAAC,UAAU,KAAK,MAAM,EAAE;YAC9B,IAAI,CAAC,UAAU,GAAG,sBAAsB,CAAC;SAC1C;QACD,IAAI,IAAI,CAAC,UAAU,KAAK,OAAO,IAAI,IAAI,CAAC,OAAO,IAAI,IAAI,EAAE;YACvD,MAAM,IAAI,KAAK,CACX,gEAAgE;gBAChE,mDAAmD,CAAC,CAAC;SAC1D;QACD,IAAI,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE;YAClC,+DAA+D;YAC/D,mBAAmB;YACnB,IAAI,CAAC,SAAS,GAAG,aAAa,CAAC,QAAQ,CACnC,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,EAAE,IAAI,CAAC,UAAoB,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;SACzE;QACD,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,YAAY,CAAC;QACpC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,UAAU,CAAC;QAChC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,YAAY,CAAC;QACpC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,UAAU,CAAC;QAChC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,YAAY,CAAC;QACpC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,UAAU,CAAC;QAChC,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC;IAC5B,CAAC;IAED,KAAK,CAAC,SAAS,CAAC,KAAa,EAAE,KAAa,EAAE,IAAoB;QAChE,MAAM,EAAE,GAA8B,EAAE,CAAC;QACzC,IAAI,IAAI,CAAC,KAAK,IAAI,IAAI,EAAE;YACtB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,IAAY,CAAC,CAAC,CAAC;SACjD;QACD,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC,CAAC;QAC9B,MAAM,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACxB,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QAE9D,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC;QAC1B,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC;SAC5C;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAE5D,MAAM,EAAE,GAA8B,EAAE,CAAC;QACzC,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC,CAAC;SAC7C;QACD,IAAI,IAAI,CAAC,UAAU,KAAK,OAAO,EAAE;YAC/B,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC,CAAC;SAC/B;QACD,MAAM,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACxB,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QAE9D,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC;SAC5C;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAE5D,MAAM,EAAE,GAA8B,EAAE,CAAC;QACzC,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC,CAAC;SAC7C;QACD,IAAI,IAAI,CAAC,UAAU,KAAK,OAAO,EAAE;YAC/B,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC,CAAC;SAC/B;aAAM,IAAI,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE;YACzC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,YAAY,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC,CAAC;SACzD;QACD,MAAM,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACxB,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,IAAqB;QAC/C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,UAAU,CAAC,IAAY,CAAC,CAAC;SACrC;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,IAAqB;QAC7C,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,QAAQ,CAAC,IAAY,CAAC,CAAC;SACnC;IACH,CAAC;CACF;AAED;;GAEG;AACH,MAAM,UAAU,oBAAoB,CAChC,SACoB,EACpB,UAA6B;IAC/B,IAAI,SAAS,IAAI,IAAI,EAAE;QACrB,SAAS,GAAG,EAAkB,CAAC;KAChC;IACD,IAAI,SAAS,YAAY,YAAY,EAAE;QACrC,OAAO,CAAC,SAAS,CAAC,CAAC;KACpB;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,SAAS,CAAC,IAAI,SAAS,CAAC,CAAC,CAAC,YAAY,YAAY,EAAE;QACpE,OAAO,SAA2B,CAAC;KACpC;IACD,8DAA8D;IAC9D,MAAM,eAAe,GACjB,aAAa,CAAC,MAAM,CAAC,SAAS,CAAyB,CAAC;IAC5D,OAAO,eAAe,CAAC,GAAG,CACtB,cAAc,CAAC,EAAE,CAAC,IAAI,cAAc,CAAC,cAAc,EAAE,UAAU,CAAC,CAAC,CAAC;AACxE,CAAC;AAMD;;;GAGG;AACH,MAAa,2BAA2B;IAItC;;OAEG;IACH,gBAAuB,CAAC;IAExB;;;;;;;;;;;OAWG;IACH,MAAM,CAAC,2BAA2B,CAC9B,cAAsB,EAAE,mBAA4C;QACtE,IAAI,CAAC,MAAM,CACP,cAAc,IAAI,CAAC,IAAI,MAAM,CAAC,SAAS,CAAC,cAAc,CAAC,EACvD,GAAG,EAAE,CAAC,qDAAqD;YACvD,WAAW,cAAc,EAAE,CAAC,CAAC;QACrC,2BAA2B,CAAC,iBAAiB,CAAC,mBAAmB,CAAC,CAAC;QACnE,IAAI,2BAA2B,CAAC,YAAY,CAAC,cAAc,CAAC,IAAI,IAAI,EAAE;YACpE,2BAA2B,CAAC,YAAY,CAAC,cAAc,CAAC,GAAG,EAAE,CAAC;SAC/D;QACD,2BAA2B,CAAC,YAAY,CAAC,cAAc,CAAC,CAAC,IAAI,CACzD,mBAAmB,CAAC,CAAC;IAC3B,CAAC;IAEO,MAAM,CAAC,iBAAiB,CAAC,mBAC2B;QAC1D,KAAK,MAAM,SAAS,IAAI,2BAA2B,CAAC,YAAY,EAAE;YAChE,MAAM,YAAY,GAAG,2BAA2B,CAAC,YAAY,CAAC,CAAC,SAAS,CAAC,CAAC;YAC1E,YAAY,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE;gBAC1B,IAAI,IAAI,KAAK,mBAAmB,EAAE;oBAChC,MAAM,IAAI,UAAU,CAAC,iCAAiC,CAAC,CAAC;iBACzD;YACH,CAAC,CAAC,CAAC;SACJ;IACH,CAAC;IAED;;OAEG;IACO,MAAM,CAAC,KAAK;QACpB,2BAA2B,CAAC,YAAY,GAAG,EAAE,CAAC;IAChD,CAAC;IAED;;;;;;;OAOG;IACH,MAAM,CAAC,eAAe,CAAC,cAAsB;QAC3C,MAAM,YAAY,GAA8B,EAAE,CAAC;QACnD,KAAK,MAAM,SAAS,IAAI,2BAA2B,CAAC,YAAY,EAAE;YAChE,MAAM,KAAK,GAAG,CAAC,SAAS,CAAC;YACzB,IAAI,cAAc,IAAI,KAAK,EAAE;gBAC3B,YAAY,CAAC,IAAI,CAAC,GAAG,2BAA2B,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC;aACvE;SACF;QACD,OAAO,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE,CAAC,CAAC;IAC9C,CAAC;;AAtEc,wCAAY,GACiC,EAAE,CAAC;SAFpD,2BAA2B;AA0ExC,MAAM,UAAU,kBAAkB,CAC9B,SAAyB,EAAE,OAA8B,EAAE,MAAc,EACzE,YAAoB,EAAE,eAAuB,EAAE,aAAqB,EACpE,SAAiB,EAAE,YAAqB,EACxC,eAAyB;IAC3B,MAAM,OAAO,GAAG,IAAI,OAAO,EAAE,CAAC;IAC9B,MAAM,eAAe,GAAmB;QACtC,IAAI,UAAU,EAAE,EAAE,GAAG,2BAA2B,CAAC,eAAe,CAAC,OAAO,CAAC;KAC1E,CAAC;IACF,IAAI,SAAS,IAAI,IAAI,EAAE;QACrB,eAAe,CAAC,IAAI,CAAC,GAAG,SAAS,CAAC,CAAC;KACpC;IACD,eAAe,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IAC9B,MAAM,YAAY,GAAG,IAAI,YAAY,CAAC,eAAe,CAAC,CAAC;IAEvD,mEAAmE;IACnE,cAAc;IACd,wDAAwD;IAExD,YAAY,CAAC,SAAS,CAAC;QACrB,MAAM;QACN,YAAY;QACZ,OAAO,EAAE,eAAe;QACxB,KAAK,EAAE,aAAa;QACpB,SAAS;QACT,OAAO;QACP,YAAY;QACZ,OAAO,EAAE,eAAe;KACzB,CAAC,CAAC;IACH,OAAO,EAAC,YAAY,EAAE,OAAO,EAAC,CAAC;AACjC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* Original source: keras/callbacks.py */\n\nimport {add, div, keep, mul, nextFrame, Scalar, Tensor, tidy, util} from '@tensorflow/tfjs-core';\n\nimport {Container} from './engine/container';\nimport {ValueError} from './errors';\nimport {Logs, resolveScalarsInLogs, UnresolvedLogs} from './logs';\nimport * as generic_utils from './utils/generic_utils';\n\n/** Verbosity logging level when fitting a model. */\nexport enum ModelLoggingVerbosity {\n  SILENT = 0,\n  VERBOSE = 1\n}\n\n/** How often to yield to the main thread when training (in ms). */\nexport const DEFAULT_YIELD_EVERY_MS = 125;\n\nexport type Params = {\n  [key: string]: number|string|boolean|number[]|string[]|boolean[];\n};\n\nexport type YieldEveryOptions = 'auto'|'batch'|'epoch'|'never'|number;\n\n/**\n * Abstract base class used to build new callbacks.\n *\n * The `logs` dictionary that callback methods take as argument will contain\n * keys for quantities relevant to the current batch or epoch.\n *\n * Currently, the `.fit()` method of the `Sequential` model class\n * will include the following quantities in the `logs` that\n * it passes to its callbacks:\n *\n * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`\n *   (if validation is enabled in `fit`), and `valAcc` (if validation and\n *   accuracy monitoring are enabled).\n * onBatchBegin: Logs include `size`, the number of samples in the current\n *   batch.\n * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring\n *   is enabled).\n */\nexport abstract class BaseCallback {\n  // TODO(michaelterry): This type is a best guess.\n  validationData: Tensor|Tensor[] = null;\n  /**\n   * Training parameters (eg. verbosity, batch size, number of epochs...).\n   */\n  params: Params;\n\n  setParams(params: Params): void {\n    this.params = params;\n  }\n\n  async onEpochBegin(epoch: number, logs?: UnresolvedLogs) {}\n\n  async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {}\n\n  async onBatchBegin(batch: number, logs?: UnresolvedLogs) {}\n\n  async onBatchEnd(batch: number, logs?: UnresolvedLogs) {}\n\n  async onTrainBegin(logs?: UnresolvedLogs) {}\n\n  async onTrainEnd(logs?: UnresolvedLogs) {}\n\n  // LayersModel needs to call Callback.setModel(), but cannot actually depend\n  // on Callback because that creates a cyclic dependency.  Providing this no-op\n  // method on BaseCallback breaks the cycle: this way LayersModel can depend on\n  // BaseCallback but not on Callback.  The argument is typed as `Container`\n  // (the superclass of LayersModel) to avoid recapitulating the cycle. Callback\n  // overrides this method and enforces that the argument is really a\n  // LayersModel.\n  setModel(model: Container): void {\n    // Do nothing. Use Callback instead of BaseCallback to track the model.\n  }\n}\n\n/**\n * Container abstracting a list of callbacks.\n */\nexport class CallbackList {\n  callbacks: BaseCallback[];\n  queueLength: number;\n\n  // TODO(cais): When the need arises, uncomment the following lines and\n  // implement the queue for time values.\n  // private deltaTBatch: number;\n  // private deltaTsBatchBegin: Array<number>;\n  // private deltaTsBatchEnd: Array<number>;\n\n  /**\n   * Constructor of CallbackList.\n   * @param callbacks Array of `Callback` instances.\n   * @param queueLength Queue length for keeping running statistics over\n   *   callback execution time.\n   */\n  constructor(callbacks?: BaseCallback[], queueLength = 10) {\n    // TODO(cais): Make use of queueLength when implementing the queue for time\n    // values.\n    if (callbacks == null) {\n      callbacks = [];\n    }\n    this.callbacks = callbacks;\n    this.queueLength = queueLength;\n  }\n\n  append(callback: BaseCallback): void {\n    this.callbacks.push(callback);\n  }\n\n  setParams(params: Params): void {\n    for (const callback of this.callbacks) {\n      callback.setParams(params);\n    }\n  }\n\n  setModel(model: Container): void {\n    for (const callback of this.callbacks) {\n      callback.setModel(model);\n    }\n  }\n\n  /**\n   * Called at the start of an epoch.\n   * @param epoch Index of epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onEpochBegin(epoch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onEpochBegin(epoch, logs);\n    }\n  }\n\n  /**\n   * Called at the end of an epoch.\n   * @param epoch Index of epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onEpochEnd(epoch, logs);\n    }\n  }\n\n  /**\n   * Called  right before processing a batch.\n   * @param batch Index of batch within the current epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onBatchBegin(batch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onBatchBegin(batch, logs);\n    }\n  }\n\n  /**\n   * Called at the end of a batch.\n   * @param batch Index of batch within the current epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onBatchEnd(batch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onBatchEnd(batch, logs);\n    }\n  }\n\n  /**\n   * Called at the beginning of training.\n   * @param logs Dictionary of logs.\n   */\n  async onTrainBegin(logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onTrainBegin(logs);\n    }\n  }\n\n  /**\n   * Called at the end of training.\n   * @param logs Dictionary of logs.\n   */\n  async onTrainEnd(logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onTrainEnd(logs);\n    }\n  }\n}\n\n/**\n * Callback that accumulates epoch averages of metrics.\n *\n * This callback is automatically applied to every LayersModel.\n */\nexport class BaseLogger extends BaseCallback {\n  private seen: number;\n  private totals: UnresolvedLogs;\n\n  constructor() {\n    super();\n  }\n\n  override async onEpochBegin(epoch: number) {\n    this.seen = 0;\n    this.totals = {};\n  }\n\n  override async onBatchEnd(batch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    const batchSize = logs['size'] == null ? 0 : logs['size'] as number;\n    this.seen += batchSize;\n    for (const key in logs) {\n      const value = logs[key];\n      if (typeof value === 'number') {\n        if (!this.totals.hasOwnProperty(key)) {\n          this.totals[key] = 0;\n        }\n        this.totals[key] = this.totals[key] as number + value * batchSize;\n      } else {\n        let oldTotalsToDispose: Scalar;\n        if (key in this.totals) {\n          oldTotalsToDispose = this.totals[key] as Scalar;\n        } else {\n          this.totals[key] = 0;\n        }\n        const total: Scalar =\n            tidy(() => add((this.totals[key]), mul(value, batchSize)));\n        this.totals[key] = total;\n        if (oldTotalsToDispose != null) {\n          oldTotalsToDispose.dispose();\n        }\n      }\n    }\n  }\n\n  override async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {\n    if (logs != null) {\n      for (const key of this.params['metrics'] as string[]) {\n        if (this.totals[key] == null) {\n          continue;\n        }\n        if (typeof this.totals[key] === 'number') {\n          logs[key] = this.totals[key] as number / this.seen;\n        } else {\n          tidy(() => {\n            const log: Scalar = mul(div(1, this.seen), this.totals[key]);\n            logs[key] = log;\n            (this.totals[key] as Tensor).dispose();\n            keep(logs[key] as Scalar);\n          });\n        }\n      }\n    }\n  }\n}\n\n/**\n * Callback that records events into a `History` object. This callback is\n * automatically applied to every TF.js Layers model. The `History` object\n * gets returned by the `fit` method of models.\n */\nexport class History extends BaseCallback {\n  epoch: number[];\n  history: {[key: string]: Array<number|Tensor>};\n\n  override async onTrainBegin(logs?: UnresolvedLogs) {\n    this.epoch = [];\n    this.history = {};\n  }\n\n  override async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    this.epoch.push(epoch);\n    for (const key in logs) {\n      if (this.history[key] == null) {\n        this.history[key] = [];\n      }\n      this.history[key].push(logs[key]);\n    }\n  }\n\n  /**\n   * Await the values of all losses and metrics.\n   */\n  async syncData() {\n    const promises: Array<Promise<Float32Array|Int32Array|Uint8Array>> = [];\n    const keys: string[] = [];\n    const indices: number[] = [];\n    for (const key in this.history) {\n      const valueArray = this.history[key];\n      for (let i = 0; i < valueArray.length; ++i) {\n        if (typeof valueArray[i] !== 'number') {\n          const valueScalar = valueArray[i] as Tensor;\n          promises.push(valueScalar.data());\n          keys.push(key);\n          indices.push(i);\n        }\n      }\n    }\n    const values = await Promise.all(promises);\n    for (let n = 0; n < values.length; ++n) {\n      const tensorToDispose = this.history[keys[n]][indices[n]] as Tensor;\n      tensorToDispose.dispose();\n      this.history[keys[n]][indices[n]] = values[n][0];\n    }\n  }\n}\n\nexport interface CustomCallbackArgs {\n  onTrainBegin?: (logs?: Logs) => void | Promise<void>;\n  onTrainEnd?: (logs?: Logs) => void | Promise<void>;\n  onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;\n  onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;\n  onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;\n  onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;\n  onYield?: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;\n  // Used for test DI mocking.\n  nowFunc?: Function;\n  nextFrameFunc?: Function;\n}\n\n/**\n * Custom callback for training.\n */\nexport class CustomCallback extends BaseCallback {\n  protected readonly trainBegin: (logs?: Logs) => void | Promise<void>;\n  protected readonly trainEnd: (logs?: Logs) => void | Promise<void>;\n  protected readonly epochBegin:\n      (epoch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly epochEnd:\n      (epoch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly batchBegin:\n      (batch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly batchEnd:\n      (batch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly yield:\n      (epoch: number, batch: number, logs: Logs) => void | Promise<void>;\n\n  private yieldEvery: YieldEveryOptions;\n  private currentEpoch = 0;\n  public nowFunc: Function;\n  public nextFrameFunc: Function;\n\n  constructor(args: CustomCallbackArgs, yieldEvery?: YieldEveryOptions) {\n    super();\n    this.nowFunc = args.nowFunc;\n    this.nextFrameFunc = args.nextFrameFunc || nextFrame;\n    this.yieldEvery = yieldEvery || 'auto';\n    if (this.yieldEvery === 'auto') {\n      this.yieldEvery = DEFAULT_YIELD_EVERY_MS;\n    }\n    if (this.yieldEvery === 'never' && args.onYield != null) {\n      throw new Error(\n          'yieldEvery is `never` but you provided an `onYield` callback. ' +\n          'Either change `yieldEvery` or remove the callback');\n    }\n    if (util.isNumber(this.yieldEvery)) {\n      // Decorate `maybeWait` so it will be called at most once every\n      // `yieldEvery` ms.\n      this.maybeWait = generic_utils.debounce(\n          this.maybeWait.bind(this), this.yieldEvery as number, this.nowFunc);\n    }\n    this.trainBegin = args.onTrainBegin;\n    this.trainEnd = args.onTrainEnd;\n    this.epochBegin = args.onEpochBegin;\n    this.epochEnd = args.onEpochEnd;\n    this.batchBegin = args.onBatchBegin;\n    this.batchEnd = args.onBatchEnd;\n    this.yield = args.onYield;\n  }\n\n  async maybeWait(epoch: number, batch: number, logs: UnresolvedLogs) {\n    const ps: Array<void|Promise<void>> = [];\n    if (this.yield != null) {\n      await resolveScalarsInLogs(logs);\n      ps.push(this.yield(epoch, batch, logs as Logs));\n    }\n    ps.push(this.nextFrameFunc());\n    await Promise.all(ps);\n  }\n\n  override async onEpochBegin(epoch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    this.currentEpoch = epoch;\n    if (this.epochBegin != null) {\n      await resolveScalarsInLogs(logs);\n      await this.epochBegin(epoch, logs as Logs);\n    }\n  }\n\n  override async onEpochEnd(epoch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    const ps: Array<void|Promise<void>> = [];\n    if (this.epochEnd != null) {\n      await resolveScalarsInLogs(logs);\n      ps.push(this.epochEnd(epoch, logs as Logs));\n    }\n    if (this.yieldEvery === 'epoch') {\n      ps.push(this.nextFrameFunc());\n    }\n    await Promise.all(ps);\n  }\n\n  override async onBatchBegin(batch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    if (this.batchBegin != null) {\n      await resolveScalarsInLogs(logs);\n      await this.batchBegin(batch, logs as Logs);\n    }\n  }\n\n  override async onBatchEnd(batch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    const ps: Array<void|Promise<void>> = [];\n    if (this.batchEnd != null) {\n      await resolveScalarsInLogs(logs);\n      ps.push(this.batchEnd(batch, logs as Logs));\n    }\n    if (this.yieldEvery === 'batch') {\n      ps.push(this.nextFrameFunc());\n    } else if (util.isNumber(this.yieldEvery)) {\n      ps.push(this.maybeWait(this.currentEpoch, batch, logs));\n    }\n    await Promise.all(ps);\n  }\n\n  override async onTrainBegin(logs?: UnresolvedLogs): Promise<void> {\n    if (this.trainBegin != null) {\n      await resolveScalarsInLogs(logs);\n      await this.trainBegin(logs as Logs);\n    }\n  }\n\n  override async onTrainEnd(logs?: UnresolvedLogs): Promise<void> {\n    if (this.trainEnd != null) {\n      await resolveScalarsInLogs(logs);\n      await this.trainEnd(logs as Logs);\n    }\n  }\n}\n\n/**\n * Standardize callbacks or configurations of them to an Array of callbacks.\n */\nexport function standardizeCallbacks(\n    callbacks: BaseCallback|BaseCallback[]|CustomCallbackArgs|\n    CustomCallbackArgs[],\n    yieldEvery: YieldEveryOptions): BaseCallback[] {\n  if (callbacks == null) {\n    callbacks = {} as BaseCallback;\n  }\n  if (callbacks instanceof BaseCallback) {\n    return [callbacks];\n  }\n  if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {\n    return callbacks as BaseCallback[];\n  }\n  // Convert custom callback configs to custom callback objects.\n  const callbackConfigs =\n      generic_utils.toList(callbacks) as CustomCallbackArgs[];\n  return callbackConfigs.map(\n      callbackConfig => new CustomCallback(callbackConfig, yieldEvery));\n}\n\nexport declare type BaseCallbackConstructor = {\n  new (): BaseCallback\n};\n\n/**\n * A global registry for callback constructors to be used during\n * LayersModel.fit().\n */\nexport class CallbackConstructorRegistry {\n  private static constructors:\n      {[verbosityLevel: number]: BaseCallbackConstructor[]} = {};\n\n  /**\n   * Blocks public access to constructor.\n   */\n  private constructor() {}\n\n  /**\n   * Register a tf.LayersModel.fit() callback constructor.\n   *\n   * The registered callback constructor will be used to instantiate\n   * callbacks for every tf.LayersModel.fit() call afterwards.\n   *\n   * @param verbosityLevel Level of verbosity at which the `callbackConstructor`\n   *   is to be reigstered.\n   * @param callbackConstructor A no-arg constructor for `tf.Callback`.\n   * @throws Error, if the same callbackConstructor has been registered before,\n   *   either at the same or a different `verbosityLevel`.\n   */\n  static registerCallbackConstructor(\n      verbosityLevel: number, callbackConstructor: BaseCallbackConstructor) {\n    util.assert(\n        verbosityLevel >= 0 && Number.isInteger(verbosityLevel),\n        () => `Verbosity level is expected to be an integer >= 0, ` +\n            `but got ${verbosityLevel}`);\n    CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);\n    if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) {\n      CallbackConstructorRegistry.constructors[verbosityLevel] = [];\n    }\n    CallbackConstructorRegistry.constructors[verbosityLevel].push(\n        callbackConstructor);\n  }\n\n  private static checkForDuplicate(callbackConstructor:\n                                       BaseCallbackConstructor) {\n    for (const levelName in CallbackConstructorRegistry.constructors) {\n      const constructors = CallbackConstructorRegistry.constructors[+levelName];\n      constructors.forEach(ctor => {\n        if (ctor === callbackConstructor) {\n          throw new ValueError('Duplicate callback constructor.');\n        }\n      });\n    }\n  }\n\n  /**\n   * Clear all registered callback constructors.\n   */\n  protected static clear() {\n    CallbackConstructorRegistry.constructors = {};\n  }\n\n  /**\n   * Create callbacks using the registered callback constructors.\n   *\n   * Given `verbosityLevel`, all constructors registered at that level or above\n   * will be called and the instantiated callbacks will be used.\n   *\n   * @param verbosityLevel: Level of verbosity.\n   */\n  static createCallbacks(verbosityLevel: number): BaseCallback[] {\n    const constructors: BaseCallbackConstructor[] = [];\n    for (const levelName in CallbackConstructorRegistry.constructors) {\n      const level = +levelName;\n      if (verbosityLevel >= level) {\n        constructors.push(...CallbackConstructorRegistry.constructors[level]);\n      }\n    }\n    return constructors.map(ctor => new ctor());\n  }\n}\n\nexport function configureCallbacks(\n    callbacks: BaseCallback[], verbose: ModelLoggingVerbosity, epochs: number,\n    initialEpoch: number, numTrainSamples: number, stepsPerEpoch: number,\n    batchSize: number, doValidation: boolean,\n    callbackMetrics: string[]): {callbackList: CallbackList, history: History} {\n  const history = new History();\n  const actualCallbacks: BaseCallback[] = [\n    new BaseLogger(), ...CallbackConstructorRegistry.createCallbacks(verbose)\n  ];\n  if (callbacks != null) {\n    actualCallbacks.push(...callbacks);\n  }\n  actualCallbacks.push(history);\n  const callbackList = new CallbackList(actualCallbacks);\n\n  // TODO(cais): Figure out when this LayersModel instance can have a\n  // dynamically\n  //   set property called 'callback_model' as in PyKeras.\n\n  callbackList.setParams({\n    epochs,\n    initialEpoch,\n    samples: numTrainSamples,\n    steps: stepsPerEpoch,\n    batchSize,\n    verbose,\n    doValidation,\n    metrics: callbackMetrics,\n  });\n  return {callbackList, history};\n}\n"]}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"base_callbacks.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/base_callbacks.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,yCAAyC;AAEzC,OAAO,EAAC,GAAG,EAAE,GAAG,EAAE,IAAI,EAAE,GAAG,EAAE,SAAS,EAAkB,IAAI,EAAE,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGjG,OAAO,EAAC,UAAU,EAAC,MAAM,UAAU,CAAC;AACpC,OAAO,EAAO,oBAAoB,EAAiB,MAAM,QAAQ,CAAC;AAClE,OAAO,KAAK,aAAa,MAAM,uBAAuB,CAAC;AAEvD,oDAAoD;AACpD,MAAM,CAAN,IAAY,qBAGX;AAHD,WAAY,qBAAqB;IAC/B,qEAAU,CAAA;IACV,uEAAW,CAAA;AACb,CAAC,EAHW,qBAAqB,KAArB,qBAAqB,QAGhC;AAED,mEAAmE;AACnE,MAAM,CAAC,MAAM,sBAAsB,GAAG,GAAG,CAAC;AAQ1C;;;;;;;;;;;;;;;;;GAiBG;AACH,MAAM,OAAgB,YAAY;IAAlC;QACE,iDAAiD;QACjD,mBAAc,GAAoB,IAAI,CAAC;IAgCzC,CAAC;IA1BC,SAAS,CAAC,MAAc;QACtB,IAAI,CAAC,MAAM,GAAG,MAAM,CAAC;IACvB,CAAC;IAED,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAE3D,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAEzD,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAE3D,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB,IAAG,CAAC;IAEzD,KAAK,CAAC,YAAY,CAAC,IAAqB,IAAG,CAAC;IAE5C,KAAK,CAAC,UAAU,CAAC,IAAqB,IAAG,CAAC;IAE1C,4EAA4E;IAC5E,8EAA8E;IAC9E,8EAA8E;IAC9E,0EAA0E;IAC1E,8EAA8E;IAC9E,mEAAmE;IACnE,eAAe;IACf,QAAQ,CAAC,KAAgB;QACvB,uEAAuE;IACzE,CAAC;CACF;AAED;;GAEG;AACH,MAAM,OAAO,YAAY;IAIvB,sEAAsE;IACtE,uCAAuC;IACvC,+BAA+B;IAC/B,4CAA4C;IAC5C,0CAA0C;IAE1C;;;;;OAKG;IACH,YAAY,SAA0B,EAAE,WAAW,GAAG,EAAE;QACtD,2EAA2E;QAC3E,UAAU;QACV,IAAI,SAAS,IAAI,IAAI,EAAE;YACrB,SAAS,GAAG,EAAE,CAAC;SAChB;QACD,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,IAAI,CAAC,WAAW,GAAG,WAAW,CAAC;IACjC,CAAC;IAED,MAAM,CAAC,QAAsB;QAC3B,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;IAChC,CAAC;IAED,SAAS,CAAC,MAAc;QACtB,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,QAAQ,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;SAC5B;IACH,CAAC;IAED,QAAQ,CAAC,KAAgB;QACvB,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,QAAQ,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;SAC1B;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QACrD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,YAAY,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SAC1C;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QACnD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SACxC;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QACrD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,YAAY,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SAC1C;IACH,CAAC;IAED;;;;OAIG;IACH,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QACnD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,UAAU,CAAC,KAAK,EAAE,IAAI,CAAC,CAAC;SACxC;IACH,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,YAAY,CAAC,IAAqB;QACtC,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;SACnC;IACH,CAAC;IAED;;;OAGG;IACH,KAAK,CAAC,UAAU,CAAC,IAAqB;QACpC,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,KAAK,MAAM,QAAQ,IAAI,IAAI,CAAC,SAAS,EAAE;YACrC,MAAM,QAAQ,CAAC,UAAU,CAAC,IAAI,CAAC,CAAC;SACjC;IACH,CAAC;CACF;AAED;;;;GAIG;AACH,MAAM,OAAO,UAAW,SAAQ,YAAY;IAI1C;QACE,KAAK,EAAE,CAAC;IACV,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,KAAa;QACvC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC;QACd,IAAI,CAAC,MAAM,GAAG,EAAE,CAAC;IACnB,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAC5D,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,MAAM,SAAS,GAAG,IAAI,CAAC,MAAM,CAAC,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,MAAM,CAAW,CAAC;QACpE,IAAI,CAAC,IAAI,IAAI,SAAS,CAAC;QACvB,KAAK,MAAM,GAAG,IAAI,IAAI,EAAE;YACtB,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;YACxB,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;gBAC7B,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE;oBACpC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;iBACtB;gBACD,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAW,GAAG,KAAK,GAAG,SAAS,CAAC;aACnE;iBAAM;gBACL,IAAI,kBAA0B,CAAC;gBAC/B,IAAI,GAAG,IAAI,IAAI,CAAC,MAAM,EAAE;oBACtB,kBAAkB,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAW,CAAC;iBACjD;qBAAM;oBACL,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;iBACtB;gBACD,MAAM,KAAK,GACP,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,KAAK,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC;gBAC/D,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,GAAG,KAAK,CAAC;gBACzB,IAAI,kBAAkB,IAAI,IAAI,EAAE;oBAC9B,kBAAkB,CAAC,OAAO,EAAE,CAAC;iBAC9B;aACF;SACF;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAC5D,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,KAAK,MAAM,GAAG,IAAI,IAAI,CAAC,MAAM,CAAC,SAAS,CAAa,EAAE;gBACpD,IAAI,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,IAAI,IAAI,EAAE;oBAC5B,SAAS;iBACV;gBACD,IAAI,OAAO,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,KAAK,QAAQ,EAAE;oBACxC,IAAI,CAAC,GAAG,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,GAAG,CAAW,GAAG,IAAI,CAAC,IAAI,CAAC;iBACpD;qBAAM;oBACL,IAAI,CAAC,GAAG,EAAE;wBACR,MAAM,GAAG,GAAW,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,EAAE,IAAI,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC;wBAC7D,IAAI,CAAC,GAAG,CAAC,GAAG,GAAG,CAAC;wBACf,IAAI,CAAC,MAAM,CAAC,GAAG,CAAY,CAAC,OAAO,EAAE,CAAC;wBACvC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAW,CAAC,CAAC;oBAC5B,CAAC,CAAC,CAAC;iBACJ;aACF;SACF;IACH,CAAC;CACF;AAED;;;;GAIG;AACH,MAAM,OAAO,OAAQ,SAAQ,YAAY;IAI9B,KAAK,CAAC,YAAY,CAAC,IAAqB;QAC/C,IAAI,CAAC,KAAK,GAAG,EAAE,CAAC;QAChB,IAAI,CAAC,OAAO,GAAG,EAAE,CAAC;IACpB,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAC5D,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACvB,KAAK,MAAM,GAAG,IAAI,IAAI,EAAE;YACtB,IAAI,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,IAAI,IAAI,EAAE;gBAC7B,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC;aACxB;YACD,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;SACnC;IACH,CAAC;IAED;;OAEG;IACH,KAAK,CAAC,QAAQ;QACZ,MAAM,QAAQ,GAAuD,EAAE,CAAC;QACxE,MAAM,IAAI,GAAa,EAAE,CAAC;QAC1B,MAAM,OAAO,GAAa,EAAE,CAAC;QAC7B,KAAK,MAAM,GAAG,IAAI,IAAI,CAAC,OAAO,EAAE;YAC9B,MAAM,UAAU,GAAG,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;YACrC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBAC1C,IAAI,OAAO,UAAU,CAAC,CAAC,CAAC,KAAK,QAAQ,EAAE;oBACrC,MAAM,WAAW,GAAG,UAAU,CAAC,CAAC,CAAW,CAAC;oBAC5C,QAAQ,CAAC,IAAI,CAAC,WAAW,CAAC,IAAI,EAAE,CAAC,CAAC;oBAClC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;oBACf,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;iBACjB;aACF;SACF;QACD,MAAM,MAAM,GAAG,MAAM,OAAO,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC;QAC3C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;YACtC,MAAM,eAAe,GAAG,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAW,CAAC;YACpE,eAAe,CAAC,OAAO,EAAE,CAAC;YAC1B,IAAI,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SAClD;IACH,CAAC;CACF;AAeD;;GAEG;AACH,MAAM,OAAO,cAAe,SAAQ,YAAY;IAmB9C,YAAY,IAAwB,EAAE,UAA8B;QAClE,KAAK,EAAE,CAAC;QALF,iBAAY,GAAG,CAAC,CAAC;QAMvB,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,OAAO,CAAC;QAC5B,IAAI,CAAC,aAAa,GAAG,IAAI,CAAC,aAAa,IAAI,SAAS,CAAC;QACrD,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QACvC,IAAI,IAAI,CAAC,UAAU,KAAK,MAAM,EAAE;YAC9B,IAAI,CAAC,UAAU,GAAG,sBAAsB,CAAC;SAC1C;QACD,IAAI,IAAI,CAAC,UAAU,KAAK,OAAO,IAAI,IAAI,CAAC,OAAO,IAAI,IAAI,EAAE;YACvD,MAAM,IAAI,KAAK,CACX,gEAAgE;gBAChE,mDAAmD,CAAC,CAAC;SAC1D;QACD,IAAI,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE;YAClC,+DAA+D;YAC/D,mBAAmB;YACnB,IAAI,CAAC,SAAS,GAAG,aAAa,CAAC,QAAQ,CACnC,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,IAAI,CAAC,EAAE,IAAI,CAAC,UAAoB,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC;SACzE;QACD,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,YAAY,CAAC;QACpC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,UAAU,CAAC;QAChC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,YAAY,CAAC;QACpC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,UAAU,CAAC;QAChC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,YAAY,CAAC;QACpC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,UAAU,CAAC;QAChC,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC;IAC5B,CAAC;IAED,KAAK,CAAC,SAAS,CAAC,KAAa,EAAE,KAAa,EAAE,IAAoB;QAChE,MAAM,EAAE,GAA8B,EAAE,CAAC;QACzC,IAAI,IAAI,CAAC,KAAK,IAAI,IAAI,EAAE;YACtB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,EAAE,KAAK,EAAE,IAAY,CAAC,CAAC,CAAC;SACjD;QACD,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC,CAAC;QAC9B,MAAM,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACxB,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QAE9D,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC;QAC1B,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC;SAC5C;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAE5D,MAAM,EAAE,GAA8B,EAAE,CAAC;QACzC,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC,CAAC;SAC7C;QACD,IAAI,IAAI,CAAC,UAAU,KAAK,OAAO,EAAE;YAC/B,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC,CAAC;SAC/B;QACD,MAAM,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACxB,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,KAAa,EAAE,IAAqB;QAE9D,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,UAAU,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC;SAC5C;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,KAAa,EAAE,IAAqB;QAE5D,MAAM,EAAE,GAA8B,EAAE,CAAC;QACzC,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,CAAC,KAAK,EAAE,IAAY,CAAC,CAAC,CAAC;SAC7C;QACD,IAAI,IAAI,CAAC,UAAU,KAAK,OAAO,EAAE;YAC/B,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,CAAC,CAAC;SAC/B;aAAM,IAAI,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE;YACzC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,YAAY,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC,CAAC;SACzD;QACD,MAAM,OAAO,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACxB,CAAC;IAEQ,KAAK,CAAC,YAAY,CAAC,IAAqB;QAC/C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,UAAU,CAAC,IAAY,CAAC,CAAC;SACrC;IACH,CAAC;IAEQ,KAAK,CAAC,UAAU,CAAC,IAAqB;QAC7C,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;YACjC,MAAM,IAAI,CAAC,QAAQ,CAAC,IAAY,CAAC,CAAC;SACnC;IACH,CAAC;CACF;AAED;;GAEG;AACH,MAAM,UAAU,oBAAoB,CAChC,SACoB,EACpB,UAA6B;IAC/B,IAAI,SAAS,IAAI,IAAI,EAAE;QACrB,SAAS,GAAG,EAAkB,CAAC;KAChC;IACD,IAAI,SAAS,YAAY,YAAY,EAAE;QACrC,OAAO,CAAC,SAAS,CAAC,CAAC;KACpB;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,SAAS,CAAC,IAAI,SAAS,CAAC,CAAC,CAAC,YAAY,YAAY,EAAE;QACpE,OAAO,SAA2B,CAAC;KACpC;IACD,8DAA8D;IAC9D,MAAM,eAAe,GACjB,aAAa,CAAC,MAAM,CAClB,SAAS,CAAyB,CAAC;IACzC,OAAO,eAAe,CAAC,GAAG,CACtB,cAAc,CAAC,EAAE,CAAC,IAAI,cAAc,CAAC,cAAc,EAAE,UAAU,CAAC,CAAC,CAAC;AACxE,CAAC;AAMD;;;GAGG;AACH,MAAa,2BAA2B;IAItC;;OAEG;IACH,gBAAuB,CAAC;IAExB;;;;;;;;;;;OAWG;IACH,MAAM,CAAC,2BAA2B,CAC9B,cAAsB,EAAE,mBAA4C;QACtE,IAAI,CAAC,MAAM,CACP,cAAc,IAAI,CAAC,IAAI,MAAM,CAAC,SAAS,CAAC,cAAc,CAAC,EACvD,GAAG,EAAE,CAAC,qDAAqD;YACvD,WAAW,cAAc,EAAE,CAAC,CAAC;QACrC,2BAA2B,CAAC,iBAAiB,CAAC,mBAAmB,CAAC,CAAC;QACnE,IAAI,2BAA2B,CAAC,YAAY,CAAC,cAAc,CAAC,IAAI,IAAI,EAAE;YACpE,2BAA2B,CAAC,YAAY,CAAC,cAAc,CAAC,GAAG,EAAE,CAAC;SAC/D;QACD,2BAA2B,CAAC,YAAY,CAAC,cAAc,CAAC,CAAC,IAAI,CACzD,mBAAmB,CAAC,CAAC;IAC3B,CAAC;IAEO,MAAM,CAAC,iBAAiB,CAAC,mBAC2B;QAC1D,KAAK,MAAM,SAAS,IAAI,2BAA2B,CAAC,YAAY,EAAE;YAChE,MAAM,YAAY,GAAG,2BAA2B,CAAC,YAAY,CAAC,CAAC,SAAS,CAAC,CAAC;YAC1E,YAAY,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE;gBAC1B,IAAI,IAAI,KAAK,mBAAmB,EAAE;oBAChC,MAAM,IAAI,UAAU,CAAC,iCAAiC,CAAC,CAAC;iBACzD;YACH,CAAC,CAAC,CAAC;SACJ;IACH,CAAC;IAED;;OAEG;IACO,MAAM,CAAC,KAAK;QACpB,2BAA2B,CAAC,YAAY,GAAG,EAAE,CAAC;IAChD,CAAC;IAED;;;;;;;OAOG;IACH,MAAM,CAAC,eAAe,CAAC,cAAsB;QAC3C,MAAM,YAAY,GAA8B,EAAE,CAAC;QACnD,KAAK,MAAM,SAAS,IAAI,2BAA2B,CAAC,YAAY,EAAE;YAChE,MAAM,KAAK,GAAG,CAAC,SAAS,CAAC;YACzB,IAAI,cAAc,IAAI,KAAK,EAAE;gBAC3B,YAAY,CAAC,IAAI,CAAC,GAAG,2BAA2B,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC;aACvE;SACF;QACD,OAAO,YAAY,CAAC,GAAG,CAAC,IAAI,CAAC,EAAE,CAAC,IAAI,IAAI,EAAE,CAAC,CAAC;IAC9C,CAAC;;AAtEc,wCAAY,GACiC,EAAE,CAAC;SAFpD,2BAA2B;AA0ExC,MAAM,UAAU,kBAAkB,CAC9B,SAAyB,EAAE,OAA8B,EAAE,MAAc,EACzE,YAAoB,EAAE,eAAuB,EAAE,aAAqB,EACpE,SAAiB,EAAE,YAAqB,EACxC,eAAyB;IAC3B,MAAM,OAAO,GAAG,IAAI,OAAO,EAAE,CAAC;IAC9B,MAAM,eAAe,GAAmB;QACtC,IAAI,UAAU,EAAE,EAAE,GAAG,2BAA2B,CAAC,eAAe,CAAC,OAAO,CAAC;KAC1E,CAAC;IACF,IAAI,SAAS,IAAI,IAAI,EAAE;QACrB,eAAe,CAAC,IAAI,CAAC,GAAG,SAAS,CAAC,CAAC;KACpC;IACD,eAAe,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IAC9B,MAAM,YAAY,GAAG,IAAI,YAAY,CAAC,eAAe,CAAC,CAAC;IAEvD,mEAAmE;IACnE,cAAc;IACd,wDAAwD;IAExD,YAAY,CAAC,SAAS,CAAC;QACrB,MAAM;QACN,YAAY;QACZ,OAAO,EAAE,eAAe;QACxB,KAAK,EAAE,aAAa;QACpB,SAAS;QACT,OAAO;QACP,YAAY;QACZ,OAAO,EAAE,eAAe;KACzB,CAAC,CAAC;IACH,OAAO,EAAC,YAAY,EAAE,OAAO,EAAC,CAAC;AACjC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* Original source: keras/callbacks.py */\n\nimport {add, div, keep, mul, nextFrame, Scalar, Tensor, tidy, util} from '@tensorflow/tfjs-core';\n\nimport {Container} from './engine/container';\nimport {ValueError} from './errors';\nimport {Logs, resolveScalarsInLogs, UnresolvedLogs} from './logs';\nimport * as generic_utils from './utils/generic_utils';\n\n/** Verbosity logging level when fitting a model. */\nexport enum ModelLoggingVerbosity {\n  SILENT = 0,\n  VERBOSE = 1\n}\n\n/** How often to yield to the main thread when training (in ms). */\nexport const DEFAULT_YIELD_EVERY_MS = 125;\n\nexport type Params = {\n  [key: string]: number|string|boolean|number[]|string[]|boolean[];\n};\n\nexport type YieldEveryOptions = 'auto'|'batch'|'epoch'|'never'|number;\n\n/**\n * Abstract base class used to build new callbacks.\n *\n * The `logs` dictionary that callback methods take as argument will contain\n * keys for quantities relevant to the current batch or epoch.\n *\n * Currently, the `.fit()` method of the `Sequential` model class\n * will include the following quantities in the `logs` that\n * it passes to its callbacks:\n *\n * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`\n *   (if validation is enabled in `fit`), and `valAcc` (if validation and\n *   accuracy monitoring are enabled).\n * onBatchBegin: Logs include `size`, the number of samples in the current\n *   batch.\n * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring\n *   is enabled).\n */\nexport abstract class BaseCallback {\n  // TODO(michaelterry): This type is a best guess.\n  validationData: Tensor|Tensor[] = null;\n  /**\n   * Training parameters (eg. verbosity, batch size, number of epochs...).\n   */\n  params: Params;\n\n  setParams(params: Params): void {\n    this.params = params;\n  }\n\n  async onEpochBegin(epoch: number, logs?: UnresolvedLogs) {}\n\n  async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {}\n\n  async onBatchBegin(batch: number, logs?: UnresolvedLogs) {}\n\n  async onBatchEnd(batch: number, logs?: UnresolvedLogs) {}\n\n  async onTrainBegin(logs?: UnresolvedLogs) {}\n\n  async onTrainEnd(logs?: UnresolvedLogs) {}\n\n  // LayersModel needs to call Callback.setModel(), but cannot actually depend\n  // on Callback because that creates a cyclic dependency.  Providing this no-op\n  // method on BaseCallback breaks the cycle: this way LayersModel can depend on\n  // BaseCallback but not on Callback.  The argument is typed as `Container`\n  // (the superclass of LayersModel) to avoid recapitulating the cycle. Callback\n  // overrides this method and enforces that the argument is really a\n  // LayersModel.\n  setModel(model: Container): void {\n    // Do nothing. Use Callback instead of BaseCallback to track the model.\n  }\n}\n\n/**\n * Container abstracting a list of callbacks.\n */\nexport class CallbackList {\n  callbacks: BaseCallback[];\n  queueLength: number;\n\n  // TODO(cais): When the need arises, uncomment the following lines and\n  // implement the queue for time values.\n  // private deltaTBatch: number;\n  // private deltaTsBatchBegin: Array<number>;\n  // private deltaTsBatchEnd: Array<number>;\n\n  /**\n   * Constructor of CallbackList.\n   * @param callbacks Array of `Callback` instances.\n   * @param queueLength Queue length for keeping running statistics over\n   *   callback execution time.\n   */\n  constructor(callbacks?: BaseCallback[], queueLength = 10) {\n    // TODO(cais): Make use of queueLength when implementing the queue for time\n    // values.\n    if (callbacks == null) {\n      callbacks = [];\n    }\n    this.callbacks = callbacks;\n    this.queueLength = queueLength;\n  }\n\n  append(callback: BaseCallback): void {\n    this.callbacks.push(callback);\n  }\n\n  setParams(params: Params): void {\n    for (const callback of this.callbacks) {\n      callback.setParams(params);\n    }\n  }\n\n  setModel(model: Container): void {\n    for (const callback of this.callbacks) {\n      callback.setModel(model);\n    }\n  }\n\n  /**\n   * Called at the start of an epoch.\n   * @param epoch Index of epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onEpochBegin(epoch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onEpochBegin(epoch, logs);\n    }\n  }\n\n  /**\n   * Called at the end of an epoch.\n   * @param epoch Index of epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onEpochEnd(epoch, logs);\n    }\n  }\n\n  /**\n   * Called  right before processing a batch.\n   * @param batch Index of batch within the current epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onBatchBegin(batch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onBatchBegin(batch, logs);\n    }\n  }\n\n  /**\n   * Called at the end of a batch.\n   * @param batch Index of batch within the current epoch.\n   * @param logs Dictionary of logs.\n   */\n  async onBatchEnd(batch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onBatchEnd(batch, logs);\n    }\n  }\n\n  /**\n   * Called at the beginning of training.\n   * @param logs Dictionary of logs.\n   */\n  async onTrainBegin(logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onTrainBegin(logs);\n    }\n  }\n\n  /**\n   * Called at the end of training.\n   * @param logs Dictionary of logs.\n   */\n  async onTrainEnd(logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    for (const callback of this.callbacks) {\n      await callback.onTrainEnd(logs);\n    }\n  }\n}\n\n/**\n * Callback that accumulates epoch averages of metrics.\n *\n * This callback is automatically applied to every LayersModel.\n */\nexport class BaseLogger extends BaseCallback {\n  private seen: number;\n  private totals: UnresolvedLogs;\n\n  constructor() {\n    super();\n  }\n\n  override async onEpochBegin(epoch: number) {\n    this.seen = 0;\n    this.totals = {};\n  }\n\n  override async onBatchEnd(batch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    const batchSize = logs['size'] == null ? 0 : logs['size'] as number;\n    this.seen += batchSize;\n    for (const key in logs) {\n      const value = logs[key];\n      if (typeof value === 'number') {\n        if (!this.totals.hasOwnProperty(key)) {\n          this.totals[key] = 0;\n        }\n        this.totals[key] = this.totals[key] as number + value * batchSize;\n      } else {\n        let oldTotalsToDispose: Scalar;\n        if (key in this.totals) {\n          oldTotalsToDispose = this.totals[key] as Scalar;\n        } else {\n          this.totals[key] = 0;\n        }\n        const total: Scalar =\n            tidy(() => add((this.totals[key]), mul(value, batchSize)));\n        this.totals[key] = total;\n        if (oldTotalsToDispose != null) {\n          oldTotalsToDispose.dispose();\n        }\n      }\n    }\n  }\n\n  override async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {\n    if (logs != null) {\n      for (const key of this.params['metrics'] as string[]) {\n        if (this.totals[key] == null) {\n          continue;\n        }\n        if (typeof this.totals[key] === 'number') {\n          logs[key] = this.totals[key] as number / this.seen;\n        } else {\n          tidy(() => {\n            const log: Scalar = mul(div(1, this.seen), this.totals[key]);\n            logs[key] = log;\n            (this.totals[key] as Tensor).dispose();\n            keep(logs[key] as Scalar);\n          });\n        }\n      }\n    }\n  }\n}\n\n/**\n * Callback that records events into a `History` object. This callback is\n * automatically applied to every TF.js Layers model. The `History` object\n * gets returned by the `fit` method of models.\n */\nexport class History extends BaseCallback {\n  epoch: number[];\n  history: {[key: string]: Array<number|Tensor>};\n\n  override async onTrainBegin(logs?: UnresolvedLogs) {\n    this.epoch = [];\n    this.history = {};\n  }\n\n  override async onEpochEnd(epoch: number, logs?: UnresolvedLogs) {\n    if (logs == null) {\n      logs = {};\n    }\n    this.epoch.push(epoch);\n    for (const key in logs) {\n      if (this.history[key] == null) {\n        this.history[key] = [];\n      }\n      this.history[key].push(logs[key]);\n    }\n  }\n\n  /**\n   * Await the values of all losses and metrics.\n   */\n  async syncData() {\n    const promises: Array<Promise<Float32Array|Int32Array|Uint8Array>> = [];\n    const keys: string[] = [];\n    const indices: number[] = [];\n    for (const key in this.history) {\n      const valueArray = this.history[key];\n      for (let i = 0; i < valueArray.length; ++i) {\n        if (typeof valueArray[i] !== 'number') {\n          const valueScalar = valueArray[i] as Tensor;\n          promises.push(valueScalar.data());\n          keys.push(key);\n          indices.push(i);\n        }\n      }\n    }\n    const values = await Promise.all(promises);\n    for (let n = 0; n < values.length; ++n) {\n      const tensorToDispose = this.history[keys[n]][indices[n]] as Tensor;\n      tensorToDispose.dispose();\n      this.history[keys[n]][indices[n]] = values[n][0];\n    }\n  }\n}\n\nexport interface CustomCallbackArgs {\n  onTrainBegin?: (logs?: Logs) => void | Promise<void>;\n  onTrainEnd?: (logs?: Logs) => void | Promise<void>;\n  onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;\n  onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;\n  onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;\n  onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;\n  onYield?: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;\n  // Used for test DI mocking.\n  nowFunc?: Function;\n  nextFrameFunc?: Function;\n}\n\n/**\n * Custom callback for training.\n */\nexport class CustomCallback extends BaseCallback {\n  protected readonly trainBegin: (logs?: Logs) => void | Promise<void>;\n  protected readonly trainEnd: (logs?: Logs) => void | Promise<void>;\n  protected readonly epochBegin:\n      (epoch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly epochEnd:\n      (epoch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly batchBegin:\n      (batch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly batchEnd:\n      (batch: number, logs?: Logs) => void | Promise<void>;\n  protected readonly yield:\n      (epoch: number, batch: number, logs: Logs) => void | Promise<void>;\n\n  private yieldEvery: YieldEveryOptions;\n  private currentEpoch = 0;\n  public nowFunc: Function;\n  public nextFrameFunc: Function;\n\n  constructor(args: CustomCallbackArgs, yieldEvery?: YieldEveryOptions) {\n    super();\n    this.nowFunc = args.nowFunc;\n    this.nextFrameFunc = args.nextFrameFunc || nextFrame;\n    this.yieldEvery = yieldEvery || 'auto';\n    if (this.yieldEvery === 'auto') {\n      this.yieldEvery = DEFAULT_YIELD_EVERY_MS;\n    }\n    if (this.yieldEvery === 'never' && args.onYield != null) {\n      throw new Error(\n          'yieldEvery is `never` but you provided an `onYield` callback. ' +\n          'Either change `yieldEvery` or remove the callback');\n    }\n    if (util.isNumber(this.yieldEvery)) {\n      // Decorate `maybeWait` so it will be called at most once every\n      // `yieldEvery` ms.\n      this.maybeWait = generic_utils.debounce(\n          this.maybeWait.bind(this), this.yieldEvery as number, this.nowFunc);\n    }\n    this.trainBegin = args.onTrainBegin;\n    this.trainEnd = args.onTrainEnd;\n    this.epochBegin = args.onEpochBegin;\n    this.epochEnd = args.onEpochEnd;\n    this.batchBegin = args.onBatchBegin;\n    this.batchEnd = args.onBatchEnd;\n    this.yield = args.onYield;\n  }\n\n  async maybeWait(epoch: number, batch: number, logs: UnresolvedLogs) {\n    const ps: Array<void|Promise<void>> = [];\n    if (this.yield != null) {\n      await resolveScalarsInLogs(logs);\n      ps.push(this.yield(epoch, batch, logs as Logs));\n    }\n    ps.push(this.nextFrameFunc());\n    await Promise.all(ps);\n  }\n\n  override async onEpochBegin(epoch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    this.currentEpoch = epoch;\n    if (this.epochBegin != null) {\n      await resolveScalarsInLogs(logs);\n      await this.epochBegin(epoch, logs as Logs);\n    }\n  }\n\n  override async onEpochEnd(epoch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    const ps: Array<void|Promise<void>> = [];\n    if (this.epochEnd != null) {\n      await resolveScalarsInLogs(logs);\n      ps.push(this.epochEnd(epoch, logs as Logs));\n    }\n    if (this.yieldEvery === 'epoch') {\n      ps.push(this.nextFrameFunc());\n    }\n    await Promise.all(ps);\n  }\n\n  override async onBatchBegin(batch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    if (this.batchBegin != null) {\n      await resolveScalarsInLogs(logs);\n      await this.batchBegin(batch, logs as Logs);\n    }\n  }\n\n  override async onBatchEnd(batch: number, logs?: UnresolvedLogs):\n      Promise<void> {\n    const ps: Array<void|Promise<void>> = [];\n    if (this.batchEnd != null) {\n      await resolveScalarsInLogs(logs);\n      ps.push(this.batchEnd(batch, logs as Logs));\n    }\n    if (this.yieldEvery === 'batch') {\n      ps.push(this.nextFrameFunc());\n    } else if (util.isNumber(this.yieldEvery)) {\n      ps.push(this.maybeWait(this.currentEpoch, batch, logs));\n    }\n    await Promise.all(ps);\n  }\n\n  override async onTrainBegin(logs?: UnresolvedLogs): Promise<void> {\n    if (this.trainBegin != null) {\n      await resolveScalarsInLogs(logs);\n      await this.trainBegin(logs as Logs);\n    }\n  }\n\n  override async onTrainEnd(logs?: UnresolvedLogs): Promise<void> {\n    if (this.trainEnd != null) {\n      await resolveScalarsInLogs(logs);\n      await this.trainEnd(logs as Logs);\n    }\n  }\n}\n\n/**\n * Standardize callbacks or configurations of them to an Array of callbacks.\n */\nexport function standardizeCallbacks(\n    callbacks: BaseCallback|BaseCallback[]|CustomCallbackArgs|\n    CustomCallbackArgs[],\n    yieldEvery: YieldEveryOptions): BaseCallback[] {\n  if (callbacks == null) {\n    callbacks = {} as BaseCallback;\n  }\n  if (callbacks instanceof BaseCallback) {\n    return [callbacks];\n  }\n  if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {\n    return callbacks as BaseCallback[];\n  }\n  // Convert custom callback configs to custom callback objects.\n  const callbackConfigs =\n      generic_utils.toList<BaseCallback | CustomCallbackArgs>(\n        callbacks) as CustomCallbackArgs[];\n  return callbackConfigs.map(\n      callbackConfig => new CustomCallback(callbackConfig, yieldEvery));\n}\n\nexport declare type BaseCallbackConstructor = {\n  new (): BaseCallback\n};\n\n/**\n * A global registry for callback constructors to be used during\n * LayersModel.fit().\n */\nexport class CallbackConstructorRegistry {\n  private static constructors:\n      {[verbosityLevel: number]: BaseCallbackConstructor[]} = {};\n\n  /**\n   * Blocks public access to constructor.\n   */\n  private constructor() {}\n\n  /**\n   * Register a tf.LayersModel.fit() callback constructor.\n   *\n   * The registered callback constructor will be used to instantiate\n   * callbacks for every tf.LayersModel.fit() call afterwards.\n   *\n   * @param verbosityLevel Level of verbosity at which the `callbackConstructor`\n   *   is to be reigstered.\n   * @param callbackConstructor A no-arg constructor for `tf.Callback`.\n   * @throws Error, if the same callbackConstructor has been registered before,\n   *   either at the same or a different `verbosityLevel`.\n   */\n  static registerCallbackConstructor(\n      verbosityLevel: number, callbackConstructor: BaseCallbackConstructor) {\n    util.assert(\n        verbosityLevel >= 0 && Number.isInteger(verbosityLevel),\n        () => `Verbosity level is expected to be an integer >= 0, ` +\n            `but got ${verbosityLevel}`);\n    CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);\n    if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) {\n      CallbackConstructorRegistry.constructors[verbosityLevel] = [];\n    }\n    CallbackConstructorRegistry.constructors[verbosityLevel].push(\n        callbackConstructor);\n  }\n\n  private static checkForDuplicate(callbackConstructor:\n                                       BaseCallbackConstructor) {\n    for (const levelName in CallbackConstructorRegistry.constructors) {\n      const constructors = CallbackConstructorRegistry.constructors[+levelName];\n      constructors.forEach(ctor => {\n        if (ctor === callbackConstructor) {\n          throw new ValueError('Duplicate callback constructor.');\n        }\n      });\n    }\n  }\n\n  /**\n   * Clear all registered callback constructors.\n   */\n  protected static clear() {\n    CallbackConstructorRegistry.constructors = {};\n  }\n\n  /**\n   * Create callbacks using the registered callback constructors.\n   *\n   * Given `verbosityLevel`, all constructors registered at that level or above\n   * will be called and the instantiated callbacks will be used.\n   *\n   * @param verbosityLevel: Level of verbosity.\n   */\n  static createCallbacks(verbosityLevel: number): BaseCallback[] {\n    const constructors: BaseCallbackConstructor[] = [];\n    for (const levelName in CallbackConstructorRegistry.constructors) {\n      const level = +levelName;\n      if (verbosityLevel >= level) {\n        constructors.push(...CallbackConstructorRegistry.constructors[level]);\n      }\n    }\n    return constructors.map(ctor => new ctor());\n  }\n}\n\nexport function configureCallbacks(\n    callbacks: BaseCallback[], verbose: ModelLoggingVerbosity, epochs: number,\n    initialEpoch: number, numTrainSamples: number, stepsPerEpoch: number,\n    batchSize: number, doValidation: boolean,\n    callbackMetrics: string[]): {callbackList: CallbackList, history: History} {\n  const history = new History();\n  const actualCallbacks: BaseCallback[] = [\n    new BaseLogger(), ...CallbackConstructorRegistry.createCallbacks(verbose)\n  ];\n  if (callbacks != null) {\n    actualCallbacks.push(...callbacks);\n  }\n  actualCallbacks.push(history);\n  const callbackList = new CallbackList(actualCallbacks);\n\n  // TODO(cais): Figure out when this LayersModel instance can have a\n  // dynamically\n  //   set property called 'callback_model' as in PyKeras.\n\n  callbackList.setParams({\n    epochs,\n    initialEpoch,\n    samples: numTrainSamples,\n    steps: stepsPerEpoch,\n    batchSize,\n    verbose,\n    doValidation,\n    metrics: callbackMetrics,\n  });\n  return {callbackList, history};\n}\n"]}

@@ -105,2 +105,3 @@ /**

loadWeights(weights: NamedTensorMap, strict?: boolean): void;
protected parseWeights(weights: NamedTensorMap): void;
/**

@@ -107,0 +108,0 @@ * Util shared between different serialization methods.

@@ -629,2 +629,3 @@ /**

computeMask(inputs: Tensor | Tensor[], mask?: Tensor | Tensor[]): Tensor | Tensor[];
private setMaskMetadata;
/**

@@ -631,0 +632,0 @@ * Internal method to create an inbound node for the layer.

@@ -126,3 +126,3 @@ /**

*/
axis?: number;
axis?: number | number[];
}

@@ -132,3 +132,3 @@ export declare class Softmax extends Layer {

static className: string;
readonly axis: number;
readonly axis: number | number[];
readonly softmax: (t: Tensor, a?: number) => Tensor;

@@ -135,0 +135,0 @@ readonly DEFAULT_AXIS = 1;

@@ -13,3 +13,3 @@ /**

*/
import { cast, clipByValue, elu, greater, leakyRelu, mul, prelu, relu, serialization } from '@tensorflow/tfjs-core';
import { add, cast, clipByValue, elu, exp, greater, leakyRelu, logSumExp, mul, ones, prelu, relu, scalar, serialization, sub, tidy } from '@tensorflow/tfjs-core';
import { Softmax as softmaxActivation } from '../activations';

@@ -216,4 +216,25 @@ import { getConstraint, serializeConstraint } from '../constraints';

call(inputs, kwargs) {
const x = getExactlyOneTensor(inputs);
return this.softmax(x, this.axis);
// TODO(pforderique): Add tests for when `this.axis` is a number[].
return tidy(() => {
let x = getExactlyOneTensor(inputs);
const mask = kwargs['mask'];
if (mask != null) {
// Since mask is 1.0 for positions we want to keep and 0.0 for masked
// positions, this operation will create a tensor which is 0.0 for
// positions we want to attend and -1e.9 for masked positions.
const adder = mul(sub(ones(x.shape), cast(mask, x.dtype)), scalar(-1e9));
// Since we are adding it to the raw scores before the softmax, this
// is effectively the same as removing these entirely.
x = add(x, adder);
}
if (this.axis instanceof Array) {
if (this.axis.length > 1) {
return exp(sub(x, logSumExp(x, this.axis, true)));
}
else {
return this.softmax(x, this.axis[0]);
}
}
return this.softmax(x, this.axis);
});
}

@@ -234,2 +255,2 @@ computeOutputShape(inputShape) {

serialization.registerClass(Softmax);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"advanced_activations.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/layers/advanced_activations.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,EAAC,IAAI,EAAE,WAAW,EAAE,GAAG,EAAE,OAAO,EAAE,SAAS,EAAE,GAAG,EAAE,KAAK,EAAE,IAAI,EAAE,aAAa,EAAS,MAAM,uBAAuB,CAAC;AAE1H,OAAO,EAAC,OAAO,IAAI,iBAAiB,EAAC,MAAM,gBAAgB,CAAC;AAC5D,OAAO,EAAa,aAAa,EAAE,mBAAmB,EAAC,MAAM,gBAAgB,CAAC;AAC9E,OAAO,EAAC,SAAS,EAAE,KAAK,EAAY,MAAM,oBAAoB,CAAC;AAC/D,OAAO,EAAC,mBAAmB,EAAE,UAAU,EAAC,MAAM,WAAW,CAAC;AAC1D,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,EAAC,cAAc,EAAe,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAElF,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAU7E,MAAa,IAAK,SAAQ,KAAK;IAK7B,YAAY,IAAoB;QAC9B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAChC,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;QAC5B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;SAC/B;IACH,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACrC,IAAI,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC;QAC1B,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC;SAChD;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAC,CAAC;QACnE,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA9BD,kBAAkB;AACX,cAAS,GAAG,MAAM,CAAC;SAFf,IAAI;AAiCjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AASlC,MAAa,SAAU,SAAQ,KAAK;IAOlC,YAAY,IAAyB;QACnC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,SAAS,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;IAClC,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA5BD,kBAAkB;AACX,mBAAS,GAAG,WAAW,AAAd,CAAe;SAFpB,SAAS;AA+BtB,aAAa,CAAC,aAAa,CAAC,SAAS,CAAC,CAAC;AA6BvC,MAAa,KAAM,SAAQ,KAAK;IAW9B,YAAY,IAAqB;QAC/B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,8BAAyB,GAA0B,OAAO,CAAC;QAIlE,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;QAC5B,IAAI,CAAC,gBAAgB;YACjB,cAAc,CAAC,IAAI,CAAC,gBAAgB,IAAI,IAAI,CAAC,yBAAyB,CAAC,CAAC;QAC5E,IAAI,CAAC,gBAAgB,GAAG,cAAc,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;QAC9D,IAAI,CAAC,eAAe,GAAG,aAAa,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;QAC3D,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC;SACxB;aAAM,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE;YACzC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;SACnC;aAAM,IAAI,OAAO,IAAI,CAAC,UAAU,KAAK,QAAQ,EAAE;YAC9C,IAAI,CAAC,UAAU,GAAG,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;SACrC;aAAM;YACL,MAAM,IAAI,UAAU,CAChB,6DAA6D;gBAC7D,WAAW,IAAI,CAAC,UAAU,EAAE,CAAC,CAAC;SACnC;IACH,CAAC;IAEQ,KAAK,CAAC,UAAyB;QACtC,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,MAAM,UAAU,GAAU,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAC9C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,KAAK,MAAM,CAAC,IAAI,IAAI,CAAC,UAAU,EAAE;gBAC/B,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;aACvB;SACF;QACD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,SAAS,CACvB,OAAO,EAAE,UAAU,EAAE,SAAS,EAAE,IAAI,CAAC,gBAAgB,EACrD,IAAI,CAAC,gBAAgB,EAAE,IAAI,EAAE,IAAI,CAAC,eAAe,CAAC,CAAC;QACvD,kBAAkB;QAClB,MAAM,IAAI,GAA6B,EAAE,CAAC;QAC1C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBAC1C,IAAI,CAAC,CAAC,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;aACzB;SACF;QACD,IAAI,CAAC,SAAS,GAAG,CAAC,IAAI,SAAS,CAAC;gBAC9B,IAAI,EAAE,UAAU,CAAC,MAAM;gBACvB,IAAI;aACL,CAAC,CAAC,CAAC;QACJ,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACpB,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACrC,OAAO,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,EAAE,CAAC,CAAC;IAC1C,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,gBAAgB,EAAE,oBAAoB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC7D,gBAAgB,EAAE,oBAAoB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC7D,eAAe,EAAE,mBAAmB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC1D,UAAU,EAAE,IAAI,CAAC,UAAU;SAC5B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA1ED,kBAAkB;AACX,eAAS,GAAG,OAAO,AAAV,CAAW;SAFhB,KAAK;AA6ElB,aAAa,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;AASnC,MAAa,GAAI,SAAQ,KAAK;IAO5B,YAAY,IAAmB;QAC7B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,IAAI,CAAC,KAAK,IAAI,IAAI,IAAI,IAAI,CAAC,KAAK,KAAK,IAAI,CAAC,aAAa,EAAE;YAC3D,MAAM,IAAI,mBAAmB,CACzB,4BAA4B,IAAI,CAAC,KAAK,4BAA4B;gBAClE,gBAAgB,CAAC,CAAC;SACvB;QAED,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AAnCD,kBAAkB;AACX,aAAS,GAAG,KAAK,AAAR,CAAS;SAFd,GAAG;AAsChB,aAAa,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC;AASjC,MAAa,eAAgB,SAAQ,KAAK;IAOxC,YAAY,IAA+B;QACzC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;IACzD,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA7BD,kBAAkB;AACX,yBAAS,GAAG,iBAAiB,AAApB,CAAqB;SAF1B,eAAe;AAgC5B,aAAa,CAAC,aAAa,CAAC,eAAe,CAAC,CAAC;AAU7C,MAAa,OAAQ,SAAQ,KAAK;IAOhC,YAAY,IAAuB;QACjC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,iBAAY,GAAG,GAAG,CAAC;QAI1B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,OAAO,GAAG,IAAI,iBAAiB,EAAE,CAAC,KAAK,CAAC;QAC7C,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC;IAChE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC;IACpC,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,IAAI,EAAE,IAAI,CAAC,IAAI,EAAC,CAAC;QAC3D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA7BD,kBAAkB;AACX,iBAAS,GAAG,SAAS,AAAZ,CAAa;SAFlB,OAAO;AAgCpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n *  Advanced activation layers.\n */\n\nimport {cast, clipByValue, elu, greater, leakyRelu, mul, prelu, relu, serialization, Tensor} from '@tensorflow/tfjs-core';\n\nimport {Softmax as softmaxActivation} from '../activations';\nimport {Constraint, getConstraint, serializeConstraint} from '../constraints';\nimport {InputSpec, Layer, LayerArgs} from '../engine/topology';\nimport {NotImplementedError, ValueError} from '../errors';\nimport {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';\nimport {Shape} from '../keras_format/common';\nimport {getRegularizer, Regularizer, serializeRegularizer} from '../regularizers';\nimport {Kwargs} from '../types';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';\nimport {LayerVariable} from '../variables';\n\nexport declare interface ReLULayerArgs extends LayerArgs {\n  /**\n   * Float, the maximum output value.\n   */\n  maxValue?: number;\n}\n\nexport class ReLU extends Layer {\n  /** @nocollapse */\n  static className = 'ReLU';\n  maxValue: number;\n\n  constructor(args?: ReLULayerArgs) {\n    super(args == null ? {} : args);\n    this.supportsMasking = true;\n    if (args != null) {\n      this.maxValue = args.maxValue;\n    }\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    inputs = getExactlyOneTensor(inputs);\n    let output = relu(inputs);\n    if (this.maxValue != null) {\n      output = clipByValue(output, 0, this.maxValue);\n    }\n    return output;\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {maxValue: this.maxValue};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ReLU);\n\nexport declare interface LeakyReLULayerArgs extends LayerArgs {\n  /**\n   * Float `>= 0`. Negative slope coefficient. Defaults to `0.3`.\n   */\n  alpha?: number;\n}\n\nexport class LeakyReLU extends Layer {\n  /** @nocollapse */\n  static className = 'LeakyReLU';\n  readonly alpha: number;\n\n  readonly DEFAULT_ALPHA = 0.3;\n\n  constructor(args?: LeakyReLULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n    this.alpha = args.alpha == null ? this.DEFAULT_ALPHA : args.alpha;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return leakyRelu(x, this.alpha);\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {alpha: this.alpha};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(LeakyReLU);\n\nexport declare interface PReLULayerArgs extends LayerArgs {\n  /**\n   * Initializer for the learnable alpha.\n   */\n  alphaInitializer?: Initializer|InitializerIdentifier;\n\n  /**\n   * Regularizer for the learnable alpha.\n   */\n  alphaRegularizer?: Regularizer;\n\n  /**\n   * Constraint for the learnable alpha.\n   */\n  alphaConstraint?: Constraint;\n\n  /**\n   * The axes along which to share learnable parameters for the activation\n   * function. For example, if the incoming feature maps are from a 2D\n   * convolution with output shape `[numExamples, height, width, channels]`,\n   * and you wish to share parameters across space (height and width) so that\n   * each filter channels has only one set of parameters, set\n   * `shared_axes: [1, 2]`.\n   */\n  sharedAxes?: number|number[];\n}\n\nexport class PReLU extends Layer {\n  /** @nocollapse */\n  static className = 'PReLU';\n  private readonly alphaInitializer: Initializer;\n  private readonly alphaRegularizer: Regularizer;\n  private readonly alphaConstraint: Constraint;\n  private readonly sharedAxes: number[];\n  private alpha: LayerVariable;\n\n  readonly DEFAULT_ALPHA_INITIALIZER: InitializerIdentifier = 'zeros';\n\n  constructor(args?: PReLULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    this.supportsMasking = true;\n    this.alphaInitializer =\n        getInitializer(args.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER);\n    this.alphaRegularizer = getRegularizer(args.alphaRegularizer);\n    this.alphaConstraint = getConstraint(args.alphaConstraint);\n    if (args.sharedAxes == null) {\n      this.sharedAxes = null;\n    } else if (Array.isArray(args.sharedAxes)) {\n      this.sharedAxes = args.sharedAxes;\n    } else if (typeof args.sharedAxes === 'number') {\n      this.sharedAxes = [args.sharedAxes];\n    } else {\n      throw new ValueError(\n          `Expected sharedAxes to be a number or an array of numbers, ` +\n          `but got ${args.sharedAxes}`);\n    }\n  }\n\n  override build(inputShape: Shape|Shape[]) {\n    inputShape = getExactlyOneShape(inputShape);\n    const paramShape: Shape = inputShape.slice(1);\n    if (this.sharedAxes != null) {\n      for (const i of this.sharedAxes) {\n        paramShape[i - 1] = 1;\n      }\n    }\n    this.alpha = this.addWeight(\n        'alpha', paramShape, 'float32', this.alphaInitializer,\n        this.alphaRegularizer, true, this.alphaConstraint);\n    // Set input spec.\n    const axes: {[axis: number]: number} = {};\n    if (this.sharedAxes != null) {\n      for (let i = 1; i < inputShape.length; ++i) {\n        axes[i] = inputShape[i];\n      }\n    }\n    this.inputSpec = [new InputSpec({\n      ndim: inputShape.length,\n      axes,\n    })];\n    this.built = true;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    inputs = getExactlyOneTensor(inputs);\n    return prelu(inputs, this.alpha.read());\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      alphaInitializer: serializeInitializer(this.alphaInitializer),\n      alphaRegularizer: serializeRegularizer(this.alphaRegularizer),\n      alphaConstraint: serializeConstraint(this.alphaConstraint),\n      sharedAxes: this.sharedAxes\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(PReLU);\n\nexport declare interface ELULayerArgs extends LayerArgs {\n  /**\n   * Float `>= 0`. Negative slope coefficient. Defaults to `1.0`.\n   */\n  alpha?: number;\n}\n\nexport class ELU extends Layer {\n  /** @nocollapse */\n  static className = 'ELU';\n  readonly alpha: number;\n\n  readonly DEFAULT_ALPHA = 1.0;\n\n  constructor(args?: ELULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    if (args.alpha != null && args.alpha !== this.DEFAULT_ALPHA) {\n      throw new NotImplementedError(\n          `Non-default alpha value (${args.alpha}) is not supported by the ` +\n          `ELU layer yet.`);\n    }\n\n    this.alpha = args.alpha == null ? this.DEFAULT_ALPHA : args.alpha;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return elu(x);\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {alpha: this.alpha};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ELU);\n\nexport declare interface ThresholdedReLULayerArgs extends LayerArgs {\n  /**\n   * Float >= 0. Threshold location of activation.\n   */\n  theta?: number;\n}\n\nexport class ThresholdedReLU extends Layer {\n  /** @nocollapse */\n  static className = 'ThresholdedReLU';\n  readonly theta: number;\n\n  readonly DEFAULT_THETA = 1.0;\n\n  constructor(args?: ThresholdedReLULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    this.theta = args.theta == null ? this.DEFAULT_THETA : args.theta;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return mul(x, cast(greater(x, this.theta), 'float32'));\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {theta: this.theta};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ThresholdedReLU);\n\nexport declare interface SoftmaxLayerArgs extends LayerArgs {\n  /**\n   * Integer, axis along which the softmax normalization is applied.\n   * Defaults to `-1` (i.e., the last axis).\n   */\n  axis?: number;\n}\n\nexport class Softmax extends Layer {\n  /** @nocollapse */\n  static className = 'Softmax';\n  readonly axis: number;\n  readonly softmax: (t: Tensor, a?: number) => Tensor;\n  readonly DEFAULT_AXIS = 1.0;\n\n  constructor(args?: SoftmaxLayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n    this.softmax = new softmaxActivation().apply;\n    this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return this.softmax(x, this.axis);\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {axis: this.axis};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(Softmax);\n"]}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"advanced_activations.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/layers/advanced_activations.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH;;GAEG;AAEH,OAAO,EAAC,GAAG,EAAE,IAAI,EAAE,WAAW,EAAE,GAAG,EAAE,GAAG,EAAE,OAAO,EAAE,SAAS,EAAE,SAAS,EAAE,GAAG,EAAE,IAAI,EAAE,KAAK,EAAE,IAAI,EAAE,MAAM,EAAE,aAAa,EAAE,GAAG,EAAU,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAExK,OAAO,EAAC,OAAO,IAAI,iBAAiB,EAAC,MAAM,gBAAgB,CAAC;AAC5D,OAAO,EAAa,aAAa,EAAE,mBAAmB,EAAC,MAAM,gBAAgB,CAAC;AAC9E,OAAO,EAAC,SAAS,EAAE,KAAK,EAAY,MAAM,oBAAoB,CAAC;AAC/D,OAAO,EAAC,mBAAmB,EAAE,UAAU,EAAC,MAAM,WAAW,CAAC;AAC1D,OAAO,EAAC,cAAc,EAAsC,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAEzG,OAAO,EAAC,cAAc,EAAe,oBAAoB,EAAC,MAAM,iBAAiB,CAAC;AAElF,OAAO,EAAC,kBAAkB,EAAE,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AAU7E,MAAa,IAAK,SAAQ,KAAK;IAK7B,YAAY,IAAoB;QAC9B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAChC,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;QAC5B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;SAC/B;IACH,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACrC,IAAI,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC;QAC1B,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;YACzB,MAAM,GAAG,WAAW,CAAC,MAAM,EAAE,CAAC,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC;SAChD;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAC,CAAC;QACnE,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA9BD,kBAAkB;AACX,cAAS,GAAG,MAAM,CAAC;SAFf,IAAI;AAiCjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AASlC,MAAa,SAAU,SAAQ,KAAK;IAOlC,YAAY,IAAyB;QACnC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,SAAS,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;IAClC,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA5BD,kBAAkB;AACX,mBAAS,GAAG,WAAW,AAAd,CAAe;SAFpB,SAAS;AA+BtB,aAAa,CAAC,aAAa,CAAC,SAAS,CAAC,CAAC;AA6BvC,MAAa,KAAM,SAAQ,KAAK;IAW9B,YAAY,IAAqB;QAC/B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,8BAAyB,GAA0B,OAAO,CAAC;QAIlE,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;QAC5B,IAAI,CAAC,gBAAgB;YACjB,cAAc,CAAC,IAAI,CAAC,gBAAgB,IAAI,IAAI,CAAC,yBAAyB,CAAC,CAAC;QAC5E,IAAI,CAAC,gBAAgB,GAAG,cAAc,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;QAC9D,IAAI,CAAC,eAAe,GAAG,aAAa,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;QAC3D,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC;SACxB;aAAM,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,UAAU,CAAC,EAAE;YACzC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;SACnC;aAAM,IAAI,OAAO,IAAI,CAAC,UAAU,KAAK,QAAQ,EAAE;YAC9C,IAAI,CAAC,UAAU,GAAG,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;SACrC;aAAM;YACL,MAAM,IAAI,UAAU,CAChB,6DAA6D;gBAC7D,WAAW,IAAI,CAAC,UAAU,EAAE,CAAC,CAAC;SACnC;IACH,CAAC;IAEQ,KAAK,CAAC,UAAyB;QACtC,UAAU,GAAG,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC5C,MAAM,UAAU,GAAU,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;QAC9C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,KAAK,MAAM,CAAC,IAAI,IAAI,CAAC,UAAU,EAAE;gBAC/B,UAAU,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;aACvB;SACF;QACD,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,SAAS,CACvB,OAAO,EAAE,UAAU,EAAE,SAAS,EAAE,IAAI,CAAC,gBAAgB,EACrD,IAAI,CAAC,gBAAgB,EAAE,IAAI,EAAE,IAAI,CAAC,eAAe,CAAC,CAAC;QACvD,kBAAkB;QAClB,MAAM,IAAI,GAA6B,EAAE,CAAC;QAC1C,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;YAC3B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;gBAC1C,IAAI,CAAC,CAAC,CAAC,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;aACzB;SACF;QACD,IAAI,CAAC,SAAS,GAAG,CAAC,IAAI,SAAS,CAAC;gBAC9B,IAAI,EAAE,UAAU,CAAC,MAAM;gBACvB,IAAI;aACL,CAAC,CAAC,CAAC;QACJ,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACpB,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACrC,OAAO,KAAK,CAAC,MAAM,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,EAAE,CAAC,CAAC;IAC1C,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,gBAAgB,EAAE,oBAAoB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC7D,gBAAgB,EAAE,oBAAoB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC7D,eAAe,EAAE,mBAAmB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC1D,UAAU,EAAE,IAAI,CAAC,UAAU;SAC5B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA1ED,kBAAkB;AACX,eAAS,GAAG,OAAO,AAAV,CAAW;SAFhB,KAAK;AA6ElB,aAAa,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;AASnC,MAAa,GAAI,SAAQ,KAAK;IAO5B,YAAY,IAAmB;QAC7B,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,IAAI,CAAC,KAAK,IAAI,IAAI,IAAI,IAAI,CAAC,KAAK,KAAK,IAAI,CAAC,aAAa,EAAE;YAC3D,MAAM,IAAI,mBAAmB,CACzB,4BAA4B,IAAI,CAAC,KAAK,4BAA4B;gBAClE,gBAAgB,CAAC,CAAC;SACvB;QAED,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AAnCD,kBAAkB;AACX,aAAS,GAAG,KAAK,AAAR,CAAS;SAFd,GAAG;AAsChB,aAAa,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC;AASjC,MAAa,eAAgB,SAAQ,KAAK;IAOxC,YAAY,IAA+B;QACzC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,kBAAa,GAAG,GAAG,CAAC;QAI3B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QAED,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;IACpE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,MAAM,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;QACtC,OAAO,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;IACzD,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,KAAK,EAAE,IAAI,CAAC,KAAK,EAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA7BD,kBAAkB;AACX,yBAAS,GAAG,iBAAiB,AAApB,CAAqB;SAF1B,eAAe;AAgC5B,aAAa,CAAC,aAAa,CAAC,eAAe,CAAC,CAAC;AAU7C,MAAa,OAAQ,SAAQ,KAAK;IAOhC,YAAY,IAAuB;QACjC,KAAK,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;QAHzB,iBAAY,GAAG,GAAG,CAAC;QAI1B,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,IAAI,GAAG,EAAE,CAAC;SACX;QACD,IAAI,CAAC,OAAO,GAAG,IAAI,iBAAiB,EAAE,CAAC,KAAK,CAAC;QAC7C,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC;IAChE,CAAC;IAEQ,IAAI,CAAC,MAAuB,EAAE,MAAc;QACnD,mEAAmE;QACnE,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC;YACpC,MAAM,IAAI,GAAG,MAAM,CAAC,MAAM,CAAW,CAAC;YACtC,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,qEAAqE;gBACrE,kEAAkE;gBAClE,8DAA8D;gBAC9D,MAAM,KAAK,GACT,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBAE7D,oEAAoE;gBACpE,sDAAsD;gBACtD,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;aACnB;YACD,IAAI,IAAI,CAAC,IAAI,YAAY,KAAK,EAAE;gBAC9B,IAAI,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,CAAC,EAAE;oBACxB,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC;iBACnD;qBAAM;oBACL,OAAO,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;iBACtC;aACF;YACD,OAAO,IAAI,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC;QACpC,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAyB;QACnD,OAAO,UAAU,CAAC;IACpB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B,EAAC,IAAI,EAAE,IAAI,CAAC,IAAI,EAAC,CAAC;QAC3D,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AAnDD,kBAAkB;AACX,iBAAS,GAAG,SAAS,AAAZ,CAAa;SAFlB,OAAO;AAsDpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/**\n *  Advanced activation layers.\n */\n\nimport {add, cast, clipByValue, elu, exp, greater, leakyRelu, logSumExp, mul, ones, prelu, relu, scalar, serialization, sub, Tensor, tidy} from '@tensorflow/tfjs-core';\n\nimport {Softmax as softmaxActivation} from '../activations';\nimport {Constraint, getConstraint, serializeConstraint} from '../constraints';\nimport {InputSpec, Layer, LayerArgs} from '../engine/topology';\nimport {NotImplementedError, ValueError} from '../errors';\nimport {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';\nimport {Shape} from '../keras_format/common';\nimport {getRegularizer, Regularizer, serializeRegularizer} from '../regularizers';\nimport {Kwargs} from '../types';\nimport {getExactlyOneShape, getExactlyOneTensor} from '../utils/types_utils';\nimport {LayerVariable} from '../variables';\n\nexport declare interface ReLULayerArgs extends LayerArgs {\n  /**\n   * Float, the maximum output value.\n   */\n  maxValue?: number;\n}\n\nexport class ReLU extends Layer {\n  /** @nocollapse */\n  static className = 'ReLU';\n  maxValue: number;\n\n  constructor(args?: ReLULayerArgs) {\n    super(args == null ? {} : args);\n    this.supportsMasking = true;\n    if (args != null) {\n      this.maxValue = args.maxValue;\n    }\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    inputs = getExactlyOneTensor(inputs);\n    let output = relu(inputs);\n    if (this.maxValue != null) {\n      output = clipByValue(output, 0, this.maxValue);\n    }\n    return output;\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {maxValue: this.maxValue};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ReLU);\n\nexport declare interface LeakyReLULayerArgs extends LayerArgs {\n  /**\n   * Float `>= 0`. Negative slope coefficient. Defaults to `0.3`.\n   */\n  alpha?: number;\n}\n\nexport class LeakyReLU extends Layer {\n  /** @nocollapse */\n  static className = 'LeakyReLU';\n  readonly alpha: number;\n\n  readonly DEFAULT_ALPHA = 0.3;\n\n  constructor(args?: LeakyReLULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n    this.alpha = args.alpha == null ? this.DEFAULT_ALPHA : args.alpha;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return leakyRelu(x, this.alpha);\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {alpha: this.alpha};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(LeakyReLU);\n\nexport declare interface PReLULayerArgs extends LayerArgs {\n  /**\n   * Initializer for the learnable alpha.\n   */\n  alphaInitializer?: Initializer|InitializerIdentifier;\n\n  /**\n   * Regularizer for the learnable alpha.\n   */\n  alphaRegularizer?: Regularizer;\n\n  /**\n   * Constraint for the learnable alpha.\n   */\n  alphaConstraint?: Constraint;\n\n  /**\n   * The axes along which to share learnable parameters for the activation\n   * function. For example, if the incoming feature maps are from a 2D\n   * convolution with output shape `[numExamples, height, width, channels]`,\n   * and you wish to share parameters across space (height and width) so that\n   * each filter channels has only one set of parameters, set\n   * `shared_axes: [1, 2]`.\n   */\n  sharedAxes?: number|number[];\n}\n\nexport class PReLU extends Layer {\n  /** @nocollapse */\n  static className = 'PReLU';\n  private readonly alphaInitializer: Initializer;\n  private readonly alphaRegularizer: Regularizer;\n  private readonly alphaConstraint: Constraint;\n  private readonly sharedAxes: number[];\n  private alpha: LayerVariable;\n\n  readonly DEFAULT_ALPHA_INITIALIZER: InitializerIdentifier = 'zeros';\n\n  constructor(args?: PReLULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    this.supportsMasking = true;\n    this.alphaInitializer =\n        getInitializer(args.alphaInitializer || this.DEFAULT_ALPHA_INITIALIZER);\n    this.alphaRegularizer = getRegularizer(args.alphaRegularizer);\n    this.alphaConstraint = getConstraint(args.alphaConstraint);\n    if (args.sharedAxes == null) {\n      this.sharedAxes = null;\n    } else if (Array.isArray(args.sharedAxes)) {\n      this.sharedAxes = args.sharedAxes;\n    } else if (typeof args.sharedAxes === 'number') {\n      this.sharedAxes = [args.sharedAxes];\n    } else {\n      throw new ValueError(\n          `Expected sharedAxes to be a number or an array of numbers, ` +\n          `but got ${args.sharedAxes}`);\n    }\n  }\n\n  override build(inputShape: Shape|Shape[]) {\n    inputShape = getExactlyOneShape(inputShape);\n    const paramShape: Shape = inputShape.slice(1);\n    if (this.sharedAxes != null) {\n      for (const i of this.sharedAxes) {\n        paramShape[i - 1] = 1;\n      }\n    }\n    this.alpha = this.addWeight(\n        'alpha', paramShape, 'float32', this.alphaInitializer,\n        this.alphaRegularizer, true, this.alphaConstraint);\n    // Set input spec.\n    const axes: {[axis: number]: number} = {};\n    if (this.sharedAxes != null) {\n      for (let i = 1; i < inputShape.length; ++i) {\n        axes[i] = inputShape[i];\n      }\n    }\n    this.inputSpec = [new InputSpec({\n      ndim: inputShape.length,\n      axes,\n    })];\n    this.built = true;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    inputs = getExactlyOneTensor(inputs);\n    return prelu(inputs, this.alpha.read());\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      alphaInitializer: serializeInitializer(this.alphaInitializer),\n      alphaRegularizer: serializeRegularizer(this.alphaRegularizer),\n      alphaConstraint: serializeConstraint(this.alphaConstraint),\n      sharedAxes: this.sharedAxes\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(PReLU);\n\nexport declare interface ELULayerArgs extends LayerArgs {\n  /**\n   * Float `>= 0`. Negative slope coefficient. Defaults to `1.0`.\n   */\n  alpha?: number;\n}\n\nexport class ELU extends Layer {\n  /** @nocollapse */\n  static className = 'ELU';\n  readonly alpha: number;\n\n  readonly DEFAULT_ALPHA = 1.0;\n\n  constructor(args?: ELULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    if (args.alpha != null && args.alpha !== this.DEFAULT_ALPHA) {\n      throw new NotImplementedError(\n          `Non-default alpha value (${args.alpha}) is not supported by the ` +\n          `ELU layer yet.`);\n    }\n\n    this.alpha = args.alpha == null ? this.DEFAULT_ALPHA : args.alpha;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return elu(x);\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {alpha: this.alpha};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ELU);\n\nexport declare interface ThresholdedReLULayerArgs extends LayerArgs {\n  /**\n   * Float >= 0. Threshold location of activation.\n   */\n  theta?: number;\n}\n\nexport class ThresholdedReLU extends Layer {\n  /** @nocollapse */\n  static className = 'ThresholdedReLU';\n  readonly theta: number;\n\n  readonly DEFAULT_THETA = 1.0;\n\n  constructor(args?: ThresholdedReLULayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n\n    this.theta = args.theta == null ? this.DEFAULT_THETA : args.theta;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    const x = getExactlyOneTensor(inputs);\n    return mul(x, cast(greater(x, this.theta), 'float32'));\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {theta: this.theta};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(ThresholdedReLU);\n\nexport declare interface SoftmaxLayerArgs extends LayerArgs {\n  /**\n   * Integer, axis along which the softmax normalization is applied.\n   * Defaults to `-1` (i.e., the last axis).\n   */\n  axis?: number|number[];\n}\n\nexport class Softmax extends Layer {\n  /** @nocollapse */\n  static className = 'Softmax';\n  readonly axis: number|number[];\n  readonly softmax: (t: Tensor, a?: number) => Tensor;\n  readonly DEFAULT_AXIS = 1.0;\n\n  constructor(args?: SoftmaxLayerArgs) {\n    super(args == null ? {} : args);\n    if (args == null) {\n      args = {};\n    }\n    this.softmax = new softmaxActivation().apply;\n    this.axis = args.axis == null ? this.DEFAULT_AXIS : args.axis;\n  }\n\n  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {\n    // TODO(pforderique): Add tests for when `this.axis` is a number[].\n    return tidy(() => {\n      let x = getExactlyOneTensor(inputs);\n      const mask = kwargs['mask'] as Tensor;\n      if (mask != null) {\n        // Since mask is 1.0 for positions we want to keep and 0.0 for masked\n        // positions, this operation will create a tensor which is 0.0 for\n        // positions we want to attend and -1e.9 for masked positions.\n        const adder =\n          mul(sub(ones(x.shape), cast(mask, x.dtype)), scalar(-1e9));\n\n        // Since we are adding it to the raw scores before the softmax, this\n        // is effectively the same as removing these entirely.\n        x = add(x, adder);\n      }\n      if (this.axis instanceof Array) {\n        if (this.axis.length > 1) {\n          return exp(sub(x, logSumExp(x, this.axis, true)));\n        } else {\n          return this.softmax(x, this.axis[0]);\n        }\n      }\n      return this.softmax(x, this.axis);\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {\n    return inputShape;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {axis: this.axis};\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(Softmax);\n"]}

@@ -21,3 +21,3 @@ /**

*/
import { Tensor, Tensor1D, Tensor2D } from '@tensorflow/tfjs-core';
import { Tensor } from '@tensorflow/tfjs-core';
import { MultiHeadAttention } from '../multihead_attention';

@@ -63,3 +63,3 @@ export declare interface CachedMultiHeadAttentionOptions {

*/
cacheUpdateIndex?: number | Tensor;
cacheUpdateIndex?: number;
}

@@ -97,7 +97,7 @@ /**

export declare class CachedMultiHeadAttention extends MultiHeadAttention {
call(query: Tensor, kwargs: CachedMultiHeadAttentionOptions): Tensor | Tensor2D;
call(query: Tensor, kwargs: CachedMultiHeadAttentionOptions): Tensor;
/**
* Exactly like `call` except also returns the updated cache.
*/
callAndReturnCache(query: Tensor, kwargs: CachedMultiHeadAttentionOptions): [Tensor1D | Tensor2D, Tensor1D | Tensor2D];
callAndReturnCache(query: Tensor, { value, key, attentionMask, cache, cacheUpdateIndex }: CachedMultiHeadAttentionOptions): [Tensor, Tensor];
}

@@ -21,5 +21,6 @@ /**

/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */
import { serialization } from '@tensorflow/tfjs-core';
import { cast, einsum, mul, reciprocal, serialization, sqrt, stack, tidy } from '@tensorflow/tfjs-core';
import { ValueError } from '../../../errors';
import { MultiHeadAttention } from '../multihead_attention';
import { NotImplementedError } from '../../../errors';
import { sliceUpdate } from '../utils';
/**

@@ -62,7 +63,52 @@ * MultiHeadAttention layer with cache support.

*/
callAndReturnCache(query, kwargs) {
throw new NotImplementedError(`Not implemented yet.`);
callAndReturnCache(query, { value, key, attentionMask, cache, cacheUpdateIndex }) {
return tidy(() => {
if (!this.builtFromSignature) {
this.buildFromSignature(query.shape, value.shape, key ? key.shape : null);
}
if (key == null) {
key = value;
}
query = this.queryDense.apply(query);
// If cache is not `null`, we will use the cache to compute the final key
// and value tensors. If `cacheUpdateIndex` is not `null`, we will first
// update the cache before use. To do this, we first call the
// `keyDense` and `valueDense` layers, and copy the outputs into the
// cache at the specified index. `cache = null` handles the training
// case, where we don't use the cache at all.
if (cache != null) {
const keyCache = cache.gather([0], 1).squeeze();
const valueCache = cache.gather([1], 1).squeeze();
if (cacheUpdateIndex == null) {
key = keyCache;
value = valueCache;
}
else {
const keyUpdate = this.keyDense.apply(key);
const valueUpdate = this.valueDense.apply(value);
const start = [0, cacheUpdateIndex, 0, 0];
key = sliceUpdate(keyCache, start, keyUpdate);
value = sliceUpdate(valueCache, start, valueUpdate);
cache = stack([key, value], 1);
}
}
else {
if (cacheUpdateIndex != null) {
throw new ValueError('`cacheUpdateIndex` should not be set if `cache` is `null`. ' +
`Received: cache=${cache}, cacheUpdateIndex=${cacheUpdateIndex}`);
}
key = this.keyDense.apply(key);
value = this.valueDense.apply(value);
}
query = mul(query, reciprocal(sqrt(cast(this.keyDim, query.dtype))));
let attentionScores = einsum(this.dotProductEquation, key, query);
attentionScores = this.maskedSoftmax(attentionScores, attentionMask);
attentionScores = this.dropoutLayer.apply(attentionScores);
let attentionOutput = einsum(this.combineEquation, attentionScores, value);
attentionOutput = this.outputDense.apply(attentionOutput);
return [attentionOutput, cache];
});
}
}
serialization.registerClass(CachedMultiHeadAttention);
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY2FjaGVkX211bHRpaGVhZF9hdHRlbnRpb24uanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMvbGF5ZXJzL25scC9tb2RlbGluZy9jYWNoZWRfbXVsdGloZWFkX2F0dGVudGlvbi50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSDs7R0FFRztBQUVILCtFQUErRTtBQUMvRSxPQUFPLEVBQThCLGFBQWEsRUFBRSxNQUFNLHVCQUF1QixDQUFDO0FBRWxGLE9BQU8sRUFBRSxrQkFBa0IsRUFBRSxNQUFNLHdCQUF3QixDQUFDO0FBQzVELE9BQU8sRUFBRSxtQkFBbUIsRUFBRSxNQUFNLGlCQUFpQixDQUFDO0FBaUR0RDs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0E2Qkc7QUFDSCxNQUFNLE9BQU8sd0JBQXlCLFNBQVEsa0JBQWtCO0lBRXJELElBQUksQ0FDWCxLQUFhLEVBQUUsTUFBdUM7UUFFdEQsT0FBTyxJQUFJLENBQUMsa0JBQWtCLENBQUMsS0FBSyxFQUFFLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ25ELENBQUM7SUFFRDs7T0FFRztJQUNILGtCQUFrQixDQUNoQixLQUFhLEVBQUUsTUFBdUM7UUFFdEQsTUFBTSxJQUFJLG1CQUFtQixDQUFDLHNCQUFzQixDQUFDLENBQUM7SUFDeEQsQ0FBQztDQUNGO0FBQ0QsYUFBYSxDQUFDLGFBQWEsQ0FBQyx3QkFBd0IsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjMgR29vZ2xlIExMQy5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG4vKipcbiAqICBDYWNoZWQgTUhBIGxheWVyIGJhc2VkIG9uIGBNdWx0aUhlYWRBdHRlbnRpb25gLlxuICovXG5cbi8qIE9yaWdpbmFsIHNvdXJjZToga2VyYXNfbmxwL2xheWVycy9tb2RlbGluZy9jYWNoZWRfbXVsdGlfaGVhZF9hdHRlbnRpb24ucHkgKi9cbmltcG9ydCB7IFRlbnNvciwgVGVuc29yMUQsIFRlbnNvcjJELCBzZXJpYWxpemF0aW9uIH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHsgTXVsdGlIZWFkQXR0ZW50aW9uIH0gZnJvbSAnLi4vbXVsdGloZWFkX2F0dGVudGlvbic7XG5pbXBvcnQgeyBOb3RJbXBsZW1lbnRlZEVycm9yIH0gZnJvbSAnLi4vLi4vLi4vZXJyb3JzJztcblxuZXhwb3J0IGRlY2xhcmUgaW50ZXJmYWNlIENhY2hlZE11bHRpSGVhZEF0dGVudGlvbk9wdGlvbnMge1xuICAvKipcbiAgICogUXVlcnkgYFRlbnNvcmAgb2Ygc2hhcGUgYChCLCBULCBkaW0pYC5cbiAgICovXG5cbiAgLyoqXG4gICAqIFZhbHVlIGBUZW5zb3JgIG9mIHNoYXBlIGAoQiwgUyosIGRpbSlgLiBJZiBgY2FjaGVgIGlzIGBudWxsYCwgYFMqYFxuICAgKiBtdXN0IGVxdWFsIGBTYCBhbmQgbWF0Y2ggdGhlIHNoYXBlIG9mIGBhdHRlbnRpb25NYXNrYC4gSWYgYGNhY2hlYCBpc1xuICAgKiBub3QgYG51bGxgLCBgUypgIGNhbiBiZSBhbnkgbGVuZ3RoIGxlc3MgdGhhbiBgU2AsIGFuZCB0aGUgY29tcHV0ZWRcbiAgICogdmFsdWUgd2lsbCBiZSBzcGxpY2VkIGludG8gYGNhY2hlYCBhdCBgY2FjaGVVcGRhdGVJbmRleGAuXG4gICAqL1xuICB2YWx1ZTogVGVuc29yO1xuXG4gIC8qKlxuICAgKiBLZXkgYFRlbnNvcmAgb2Ygc2hhcGUgYChCLCBTKiwgZGltKWAuICBJZiBgY2FjaGVgIGlzIGBudWxsYCwgYFMqYCBtdXN0XG4gICAqIGVxdWFsIGBTYCBhbmQgbWF0Y2ggdGhlIHNoYXBlIG9mIGBhdHRlbnRpb25NYXNrYC4gSWYgYGNhY2hlYCBpcyBub3QgYG51bGxgLFxuICAgKiBgUypgIGNhbiBiZSBhbnkgbGVuZ3RoIGxlc3MgdGhhbiBgU2AsIGFuZCB0aGUgY29tcHV0ZWQgdmFsdWUgd2lsbCBiZVxuICAgKiBzcGxpY2VkIGludG8gYGNhY2hlYCBhdCBgY2FjaGVVcGRhdGVJbmRleGAuXG4gICAqL1xuICBrZXk/OiBUZW5zb3I7XG5cbiAgLyoqXG4gICAqIEEgYm9vbGVhbiBtYXNrIG9mIHNoYXBlIGAoQiwgVCwgUylgLiBgYXR0ZW50aW9uTWFza2AgcHJldmVudHNcbiAgICogYXR0ZW50aW9uIHRvIGNlcnRhaW4gcG9zaXRpb25zLiBUaGUgYm9vbGVhbiBtYXNrIHNwZWNpZmllcyB3aGljaFxuICAgKiBxdWVyeSBlbGVtZW50cyBjYW4gYXR0ZW5kIHRvIHdoaWNoIGtleSBlbGVtZW50cywgMSBpbmRpY2F0ZXNcbiAgICogYXR0ZW50aW9uIGFuZCAwIGluZGljYXRlcyBubyBhdHRlbnRpb24uIEJyb2FkY2FzdGluZyBjYW4gaGFwcGVuIGZvclxuICAgKiB0aGUgbWlzc2luZyBiYXRjaCBkaW1lbnNpb25zIGFuZCB0aGUgaGVhZCBkaW1lbnNpb24uXG4gICAqL1xuICBhdHRlbnRpb25NYXNrPzogVGVuc29yO1xuXG4gIC8qKlxuICAgKiBBIGRlbnNlIGZsb2F0IFRlbnNvci4gVGhlIGtleS92YWx1ZSBjYWNoZSwgb2Ygc2hhcGVcbiAgICogYFtCLCAyLCBTLCBudW1IZWFkcywga2V5RGltc11gLCB3aGVyZSBgU2AgbXVzdCBhZ3JlZSB3aXRoIHRoZVxuICAgKiBgYXR0ZW50aW9uTWFza2Agc2hhcGUuIFRoaXMgYXJndW1lbnQgaXMgaW50ZW5kZWQgZm9yIHVzZSBkdXJpbmdcbiAgICogZ2VuZXJhdGlvbiB0byBhdm9pZCByZWNvbXB1dGluZyBpbnRlcm1lZGlhdGUgc3RhdGUuXG4gICAqL1xuICBjYWNoZT86IFRlbnNvcjtcblxuICAvKipcbiAgICogSW50ZWdlciBvciBJbnRlZ2VyIGBUZW5zb3JgLiBUaGUgaW5kZXggYXQgd2hpY2ggdG8gdXBkYXRlIGBjYWNoZWBcbiAgICogKHVzdWFsbHkgdGhlIGluZGV4IG9mIHRoZSBjdXJyZW50IHRva2VuIGJlaW5nIHByb2Nlc3NlZCB3aGVuIHJ1bm5pbmdcbiAgICogZ2VuZXJhdGlvbikuIElmIGBjYWNoZVVwZGF0ZUluZGV4PW51bGxgIHdoaWxlIGBjYWNoZWAgaXMgc2V0LCB0aGUgY2FjaGVcbiAgICogd2lsbCBub3QgYmUgdXBkYXRlZC5cbiAgICovXG4gIGNhY2hlVXBkYXRlSW5kZXg/OiBudW1iZXJ8VGVuc29yO1xufVxuXG4vKipcbiAqIE11bHRpSGVhZEF0dGVudGlvbiBsYXllciB3aXRoIGNhY2hlIHN1cHBvcnQuXG4gKlxuICogVGhpcyBsYXllciBpcyBzdWl0YWJsZSBmb3IgdXNlIGluIGF1dG9yZWdyZXNzaXZlIGRlY29kaW5nLiBJdCBjYW4gYmUgdXNlXG4gKiB0byBjYWNoZSBkZWNvZGVyIHNlbGYtYXR0ZW50aW9uIGFuZCBjcm9zcy1hdHRlbnRpb24uIFRoZSBmb3J3YXJkIHBhc3NcbiAqIGNhbiBoYXBwZW4gaW4gb25lIG9mIHRocmVlIG1vZGVzOlxuICogLSBObyBjYWNoZSwgc2FtZSBhcyByZWd1bGFyIG11bHRpLWhlYWQgYXR0ZW50aW9uLlxuICogLSBTdGF0aWMgY2FjaGUgKGBjYWNoZVVwZGF0ZUluZGV4YCBpcyBOb25lKS4gSW4gdGhpcyBjYXNlLCB0aGVcbiAqICAgICBjYWNoZWQga2V5L3ZhbHVlIHByb2plY3Rpb25zIHdpbGwgYmUgdXNlZCBhbmQgdGhlIGlucHV0IHZhbHVlcyB3aWxsXG4gKiAgICAgYmUgaWdub3JlZC5cbiAqIC0gVXBkYXRlZCBjYWNoZSAoYGNhY2hlVXBkYXRlSW5kZXhgIGlzIG5vdCBOb25lKS4gSW4gdGhpcyBjYXNlLCBuZXdcbiAqICAgICBrZXkvdmFsdWUgcHJvamVjdGlvbnMgYXJlIGNvbXB1dGVkIHVzaW5nIHRoZSBpbnB1dCwgYW5kIHNwbGljZWQgaW50b1xuICogICAgIHRoZSBjYWNoZSBhdCB0aGUgc3BlY2lmaWVkIGluZGV4LlxuICpcbiAqIE5vdGUgdGhhdCBjYWNoaW5nIGlzIHVzZWZ1bCBvbmx5IGR1cmluZyBpbmZlcmVuY2UgYW5kIHNob3VsZCBub3QgYmUgdXNlZFxuICogZHVyaW5nIHRyYWluaW5nLlxuICpcbiAqIFdlIHVzZSB0aGUgbm90YXRpb24gYEJgLCBgVGAsIGBTYCBiZWxvdywgd2hlcmUgYEJgIGlzIHRoZSBiYXRjaCBkaW1lbnNpb24sXG4gKiBgVGAgaXMgdGhlIHRhcmdldCBzZXF1ZW5jZSBsZW5ndGgsIGFuZCBgU2AgaW4gdGhlIHNvdXJjZSBzZXF1ZW5jZSBsZW5ndGguXG4gKiBOb3RlIHRoYXQgZHVyaW5nIGdlbmVyYXRpdmUgZGVjb2RpbmcsIGBUYCBpcyB1c3VhbGx5IDEgKHlvdSBhcmVcbiAqIGdlbmVyYXRpbmcgYSB0YXJnZXQgc2VxdWVuY2Ugb2YgbGVuZ3RoIG9uZSB0byBwcmVkaWN0IHRoZSBuZXh0IHRva2VuKS5cbiAqXG4gKiBSZXR1cm5zOlxuICogICAgIEFuIGAoYXR0ZW50aW9uT3V0cHV0LCBjYWNoZSlgIHR1cGxlLiBgYXR0ZW50aW9uT3V0cHV0YCBpcyB0aGUgcmVzdWx0XG4gKiAgICAgb2YgdGhlIGNvbXB1dGF0aW9uLCBvZiBzaGFwZSBgKEIsIFQsIGRpbSlgLCB3aGVyZSBgVGAgaXMgZm9yIHRhcmdldFxuICogICAgIHNlcXVlbmNlIHNoYXBlcyBhbmQgYGRpbWAgaXMgdGhlIHF1ZXJ5IGlucHV0IGxhc3QgZGltZW5zaW9uIGlmXG4gKiAgICAgYG91dHB1dFNoYXBlYCBpcyBgbnVsbGAuIE90aGVyd2lzZSwgdGhlIG11bHRpLWhlYWQgb3V0cHV0cyBhcmVcbiAqICAgICBwcm9qZWN0ZWQgdG8gdGhlIHNoYXBlIHNwZWNpZmllZCBieSBgb3V0cHV0U2hhcGVgLiBgY2FjaGVgIGlzIHRoZVxuICogICAgIHVwZGF0ZWQgY2FjaGUuXG4gKi9cbmV4cG9ydCBjbGFzcyBDYWNoZWRNdWx0aUhlYWRBdHRlbnRpb24gZXh0ZW5kcyBNdWx0aUhlYWRBdHRlbnRpb24ge1xuXG4gIG92ZXJyaWRlIGNhbGwoXG4gICAgcXVlcnk6IFRlbnNvciwga3dhcmdzOiBDYWNoZWRNdWx0aUhlYWRBdHRlbnRpb25PcHRpb25zXG4gICk6IFRlbnNvcnxUZW5zb3IyRCB7XG4gICAgcmV0dXJuIHRoaXMuY2FsbEFuZFJldHVybkNhY2hlKHF1ZXJ5LCBrd2FyZ3MpWzBdO1xuICB9XG5cbiAgLyoqXG4gICAqIEV4YWN0bHkgbGlrZSBgY2FsbGAgZXhjZXB0IGFsc28gcmV0dXJucyB0aGUgdXBkYXRlZCBjYWNoZS5cbiAgICovXG4gIGNhbGxBbmRSZXR1cm5DYWNoZShcbiAgICBxdWVyeTogVGVuc29yLCBrd2FyZ3M6IENhY2hlZE11bHRpSGVhZEF0dGVudGlvbk9wdGlvbnNcbiAgKTogW1RlbnNvcjFEfFRlbnNvcjJELCBUZW5zb3IxRHxUZW5zb3IyRF0ge1xuICAgIHRocm93IG5ldyBOb3RJbXBsZW1lbnRlZEVycm9yKGBOb3QgaW1wbGVtZW50ZWQgeWV0LmApO1xuICB9XG59XG5zZXJpYWxpemF0aW9uLnJlZ2lzdGVyQ2xhc3MoQ2FjaGVkTXVsdGlIZWFkQXR0ZW50aW9uKTtcbiJdfQ==
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"cached_multihead_attention.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/cached_multihead_attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,+EAA+E;AAC/E,OAAO,EAAU,IAAI,EAAE,MAAM,EAAE,GAAG,EAAE,UAAU,EAAE,aAAa,EAAE,IAAI,EAAE,KAAK,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAEhH,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAE,kBAAkB,EAAE,MAAM,wBAAwB,CAAC;AAC5D,OAAO,EAAE,WAAW,EAAE,MAAM,UAAU,CAAC;AAiDvC;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA6BG;AACH,MAAM,OAAO,wBAAyB,SAAQ,kBAAkB;IAErD,IAAI,CACX,KAAa,EAAE,MAAuC;QAEtD,OAAO,IAAI,CAAC,kBAAkB,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;IACnD,CAAC;IAED;;OAEG;IACH,kBAAkB,CAChB,KAAa,EACb,EACE,KAAK,EACL,GAAG,EACH,aAAa,EACb,KAAK,EACL,gBAAgB,EACiB;QAEnC,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,IAAI,CAAC,kBAAkB,EAAE;gBAC5B,IAAI,CAAC,kBAAkB,CACrB,KAAK,CAAC,KAAK,EAAE,KAAK,CAAC,KAAK,EAAE,GAAG,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;aACrD;YACD,IAAI,GAAG,IAAI,IAAI,EAAE;gBACf,GAAG,GAAG,KAAK,CAAC;aACb;YAED,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;YAC/C,yEAAyE;YACzE,wEAAwE;YACxE,6DAA6D;YAC7D,oEAAoE;YACpE,oEAAoE;YACpE,6CAA6C;YAC7C,IAAI,KAAK,IAAI,IAAI,EAAE;gBACjB,MAAM,QAAQ,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;gBAChD,MAAM,UAAU,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;gBAClD,IAAI,gBAAgB,IAAI,IAAI,EAAE;oBAC5B,GAAG,GAAG,QAAQ,CAAC;oBACf,KAAK,GAAG,UAAU,CAAC;iBACpB;qBAAM;oBACL,MAAM,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,GAAG,CAAW,CAAC;oBACrD,MAAM,WAAW,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;oBAC3D,MAAM,KAAK,GAAG,CAAC,CAAC,EAAE,gBAAgB,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;oBAC1C,GAAG,GAAG,WAAW,CAAC,QAAQ,EAAE,KAAK,EAAE,SAAS,CAAC,CAAC;oBAC9C,KAAK,GAAG,WAAW,CAAC,UAAU,EAAE,KAAK,EAAE,WAAW,CAAC,CAAC;oBACpD,KAAK,GAAG,KAAK,CAAC,CAAC,GAAG,EAAE,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC;iBAChC;aACF;iBAAM;gBACL,IAAI,gBAAgB,IAAI,IAAI,EAAE;oBAC5B,MAAM,IAAI,UAAU,CAClB,6DAA6D;wBAC7D,mBAAmB,KAAK,sBAAsB,gBAAgB,EAAE,CACjE,CAAC;iBACH;gBACD,GAAG,GAAG,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,GAAG,CAAW,CAAC;gBACzC,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;aAChD;YAED,KAAK,GAAG,GAAG,CAAC,KAAK,EAAE,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACrE,IAAI,eAAe,GAAG,MAAM,CAAC,IAAI,CAAC,kBAAkB,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC;YAClE,eAAe,GAAG,IAAI,CAAC,aAAa,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC;YACrE,eAAe,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,eAAe,CAAW,CAAC;YAErE,IAAI,eAAe,GACjB,MAAM,CAAC,IAAI,CAAC,eAAe,EAAE,eAAe,EAAE,KAAK,CAAC,CAAC;YACvD,eAAe,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,eAAe,CAAW,CAAC;YAEpE,OAAO,CAAC,eAAe,EAAE,KAAK,CAAC,CAAC;QAClC,CAAC,CAAC,CAAC;IACL,CAAC;CACF;AACD,aAAa,CAAC,aAAa,CAAC,wBAAwB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Cached MHA layer based on `MultiHeadAttention`.\n */\n\n/* Original source: keras_nlp/layers/modeling/cached_multi_head_attention.py */\nimport { Tensor, cast, einsum, mul, reciprocal, serialization, sqrt, stack, tidy } from '@tensorflow/tfjs-core';\n\nimport { ValueError } from '../../../errors';\nimport { MultiHeadAttention } from '../multihead_attention';\nimport { sliceUpdate } from '../utils';\n\nexport declare interface CachedMultiHeadAttentionOptions {\n  /**\n   * Query `Tensor` of shape `(B, T, dim)`.\n   */\n\n  /**\n   * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*`\n   * must equal `S` and match the shape of `attentionMask`. If `cache` is\n   * not `null`, `S*` can be any length less than `S`, and the computed\n   * value will be spliced into `cache` at `cacheUpdateIndex`.\n   */\n  value: Tensor;\n\n  /**\n   * Key `Tensor` of shape `(B, S*, dim)`.  If `cache` is `null`, `S*` must\n   * equal `S` and match the shape of `attentionMask`. If `cache` is not `null`,\n   * `S*` can be any length less than `S`, and the computed value will be\n   * spliced into `cache` at `cacheUpdateIndex`.\n   */\n  key?: Tensor;\n\n  /**\n   * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents\n   * attention to certain positions. The boolean mask specifies which\n   * query elements can attend to which key elements, 1 indicates\n   * attention and 0 indicates no attention. Broadcasting can happen for\n   * the missing batch dimensions and the head dimension.\n   */\n  attentionMask?: Tensor;\n\n  /**\n   * A dense float Tensor. The key/value cache, of shape\n   * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the\n   * `attentionMask` shape. This argument is intended for use during\n   * generation to avoid recomputing intermediate state.\n   */\n  cache?: Tensor;\n\n  /**\n   * Integer or Integer `Tensor`. The index at which to update `cache`\n   * (usually the index of the current token being processed when running\n   * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache\n   * will not be updated.\n   */\n  cacheUpdateIndex?: number;\n}\n\n/**\n * MultiHeadAttention layer with cache support.\n *\n * This layer is suitable for use in autoregressive decoding. It can be use\n * to cache decoder self-attention and cross-attention. The forward pass\n * can happen in one of three modes:\n * - No cache, same as regular multi-head attention.\n * - Static cache (`cacheUpdateIndex` is None). In this case, the\n *     cached key/value projections will be used and the input values will\n *     be ignored.\n * - Updated cache (`cacheUpdateIndex` is not None). In this case, new\n *     key/value projections are computed using the input, and spliced into\n *     the cache at the specified index.\n *\n * Note that caching is useful only during inference and should not be used\n * during training.\n *\n * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,\n * `T` is the target sequence length, and `S` in the source sequence length.\n * Note that during generative decoding, `T` is usually 1 (you are\n * generating a target sequence of length one to predict the next token).\n *\n * Returns:\n *     An `(attentionOutput, cache)` tuple. `attentionOutput` is the result\n *     of the computation, of shape `(B, T, dim)`, where `T` is for target\n *     sequence shapes and `dim` is the query input last dimension if\n *     `outputShape` is `null`. Otherwise, the multi-head outputs are\n *     projected to the shape specified by `outputShape`. `cache` is the\n *     updated cache.\n */\nexport class CachedMultiHeadAttention extends MultiHeadAttention {\n\n  override call(\n    query: Tensor, kwargs: CachedMultiHeadAttentionOptions\n  ): Tensor {\n    return this.callAndReturnCache(query, kwargs)[0];\n  }\n\n  /**\n   * Exactly like `call` except also returns the updated cache.\n   */\n  callAndReturnCache(\n    query: Tensor,\n    {\n      value,\n      key,\n      attentionMask,\n      cache,\n      cacheUpdateIndex\n    } : CachedMultiHeadAttentionOptions\n  ): [Tensor, Tensor] {\n    return tidy(() => {\n      if (!this.builtFromSignature) {\n        this.buildFromSignature(\n          query.shape, value.shape, key ? key.shape : null);\n      }\n      if (key == null) {\n        key = value;\n      }\n\n      query = this.queryDense.apply(query) as Tensor;\n      // If cache is not `null`, we will use the cache to compute the final key\n      // and value tensors. If `cacheUpdateIndex` is not `null`, we will first\n      // update the cache before use. To do this, we first call the\n      // `keyDense` and `valueDense` layers, and copy the outputs into the\n      // cache at the specified index. `cache = null` handles the training\n      // case, where we don't use the cache at all.\n      if (cache != null) {\n        const keyCache = cache.gather([0], 1).squeeze();\n        const valueCache = cache.gather([1], 1).squeeze();\n        if (cacheUpdateIndex == null) {\n          key = keyCache;\n          value = valueCache;\n        } else {\n          const keyUpdate = this.keyDense.apply(key) as Tensor;\n          const valueUpdate = this.valueDense.apply(value) as Tensor;\n          const start = [0, cacheUpdateIndex, 0, 0];\n          key = sliceUpdate(keyCache, start, keyUpdate);\n          value = sliceUpdate(valueCache, start, valueUpdate);\n          cache = stack([key, value], 1);\n        }\n      } else {\n        if (cacheUpdateIndex != null) {\n          throw new ValueError(\n            '`cacheUpdateIndex` should not be set if `cache` is `null`. ' +\n            `Received: cache=${cache}, cacheUpdateIndex=${cacheUpdateIndex}`\n          );\n        }\n        key = this.keyDense.apply(key) as Tensor;\n        value = this.valueDense.apply(value) as Tensor;\n      }\n\n      query = mul(query, reciprocal(sqrt(cast(this.keyDim, query.dtype))));\n      let attentionScores = einsum(this.dotProductEquation, key, query);\n      attentionScores = this.maskedSoftmax(attentionScores, attentionMask);\n      attentionScores = this.dropoutLayer.apply(attentionScores) as Tensor;\n\n      let attentionOutput =\n        einsum(this.combineEquation, attentionScores, value);\n      attentionOutput = this.outputDense.apply(attentionOutput) as Tensor;\n\n      return [attentionOutput, cache];\n    });\n  }\n}\nserialization.registerClass(CachedMultiHeadAttention);\n"]}

@@ -21,6 +21,7 @@ /**

*/
import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core';
import { Tensor, serialization } from '@tensorflow/tfjs-core';
import { Shape } from '../../../keras_format/common';
import { Layer, LayerArgs } from '../../../engine/topology';
import { InitializerIdentifier } from '../../../initializers';
import { Initializer, InitializerIdentifier } from '../../../initializers';
import { LayerVariable } from '../../../variables';
export declare interface PositionEmbeddingArgs extends LayerArgs {

@@ -35,3 +36,3 @@ /**

*/
initializer?: InitializerIdentifier;
initializer?: Initializer | InitializerIdentifier;
}

@@ -81,6 +82,10 @@ export declare interface PositionEmbeddingOptions {

static readonly className = "PositionEmbedding";
private sequenceLength;
private initializer;
protected positionEmbeddings: LayerVariable;
constructor(args: PositionEmbeddingArgs);
getConfig(): serialization.ConfigDict;
build(inputShape: Shape | Shape[]): void;
call(inputs: Tensor | Tensor[], kwargs?: PositionEmbeddingOptions): Tensor1D | Tensor2D;
build(inputShape: Shape): void;
call(inputs: Tensor | Tensor[], kwargs?: PositionEmbeddingOptions): Tensor;
computeOutputShape(inputShape: Shape): Shape;
}

@@ -21,5 +21,7 @@ /**

/* Original source: keras_nlp/layers/modeling/position_embedding.py */
import { serialization } from '@tensorflow/tfjs-core';
import { serialization, tidy } from '@tensorflow/tfjs-core';
import { Layer } from '../../../engine/topology';
import { NotImplementedError } from '../../../errors';
import { ValueError } from '../../../errors';
import { getInitializer, serializeInitializer } from '../../../initializers';
import { getExactlyOneTensor } from '../../../utils/types_utils';
/**

@@ -61,13 +63,38 @@ * A layer which learns a position embedding for input sequences.

super(args);
throw new NotImplementedError('PositionEmbedding not implemented yet.');
if (args.sequenceLength == null) {
throw new ValueError('`sequenceLength` must be an Integer, received `null`.');
}
this.sequenceLength = args.sequenceLength;
this.initializer = getInitializer(args.initializer || 'glorotUniform');
}
getConfig() {
throw new NotImplementedError('Not implemented yet.');
const config = {
'sequenceLength': this.sequenceLength,
'initializer': serializeInitializer(this.initializer),
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
build(inputShape) {
throw new NotImplementedError('Not implemented yet.');
const featureSize = inputShape[inputShape.length - 1];
this.positionEmbeddings = this.addWeight('embeddings', [this.sequenceLength, featureSize], null, this.initializer, null, true);
super.build(inputShape);
}
call(inputs, kwargs = { startIndex: 0 }) {
throw new NotImplementedError('Not implemented yet.');
call(inputs, kwargs) {
return tidy(() => {
var _a;
kwargs.startIndex = (_a = kwargs.startIndex) !== null && _a !== void 0 ? _a : 0;
const shape = getExactlyOneTensor(inputs).shape;
const featureLength = shape[shape.length - 1];
const sequenceLength = shape[shape.length - 2];
// trim to match the length of the input sequence, which might be less
// than the sequence_length of the layer.
const positionEmbeddings = this.positionEmbeddings.read().slice([kwargs.startIndex, 0], [sequenceLength, featureLength]);
return positionEmbeddings.broadcastTo(shape);
});
}
computeOutputShape(inputShape) {
return inputShape;
}
}

@@ -78,2 +105,2 @@ /** @nocollapse */

serialization.registerClass(PositionEmbedding);
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoicG9zaXRpb25fZW1iZWRkaW5nLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1sYXllcnMvc3JjL2xheWVycy9ubHAvbW9kZWxpbmcvcG9zaXRpb25fZW1iZWRkaW5nLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVIOztHQUVHO0FBRUgsc0VBQXNFO0FBQ3RFLE9BQU8sRUFBOEIsYUFBYSxFQUFFLE1BQU0sdUJBQXVCLENBQUM7QUFHbEYsT0FBTyxFQUFFLEtBQUssRUFBYSxNQUFNLDBCQUEwQixDQUFDO0FBQzVELE9BQU8sRUFBRSxtQkFBbUIsRUFBRSxNQUFNLGlCQUFpQixDQUFDO0FBd0J0RDs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7R0FnQ0c7QUFDSCxNQUFhLGlCQUFrQixTQUFRLEtBQUs7SUFJMUMsWUFBWSxJQUEyQjtRQUNyQyxLQUFLLENBQUMsSUFBSSxDQUFDLENBQUM7UUFFWixNQUFNLElBQUksbUJBQW1CLENBQUMsd0NBQXdDLENBQUMsQ0FBQztJQUMxRSxDQUFDO0lBRVEsU0FBUztRQUNoQixNQUFNLElBQUksbUJBQW1CLENBQUMsc0JBQXNCLENBQUMsQ0FBQztJQUN4RCxDQUFDO0lBRVEsS0FBSyxDQUFDLFVBQTJCO1FBQ3hDLE1BQU0sSUFBSSxtQkFBbUIsQ0FBQyxzQkFBc0IsQ0FBQyxDQUFDO0lBQ3hELENBQUM7SUFFUSxJQUFJLENBQ1gsTUFBdUIsRUFDdkIsU0FBaUMsRUFBQyxVQUFVLEVBQUUsQ0FBQyxFQUFDO1FBRWhELE1BQU0sSUFBSSxtQkFBbUIsQ0FBQyxzQkFBc0IsQ0FBQyxDQUFDO0lBQ3hELENBQUM7O0FBdEJELGtCQUFrQjtBQUNGLDJCQUFTLEdBQUcsbUJBQW1CLENBQUM7U0FGckMsaUJBQWlCO0FBeUI5QixhQUFhLENBQUMsYUFBYSxDQUFDLGlCQUFpQixDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbi8qKlxuICogIFBvc2l0aW9uIGVtYmVkZGluZyBpbXBsZW1lbnRhdGlvbiBiYXNlZCBvbiBgdGYubGF5ZXJzLkxheWVyYC5cbiAqL1xuXG4vKiBPcmlnaW5hbCBzb3VyY2U6IGtlcmFzX25scC9sYXllcnMvbW9kZWxpbmcvcG9zaXRpb25fZW1iZWRkaW5nLnB5ICovXG5pbXBvcnQgeyBUZW5zb3IsIFRlbnNvcjFELCBUZW5zb3IyRCwgc2VyaWFsaXphdGlvbiB9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7IFNoYXBlIH0gZnJvbSAnLi4vLi4vLi4va2VyYXNfZm9ybWF0L2NvbW1vbic7XG5pbXBvcnQgeyBMYXllciwgTGF5ZXJBcmdzIH0gZnJvbSAnLi4vLi4vLi4vZW5naW5lL3RvcG9sb2d5JztcbmltcG9ydCB7IE5vdEltcGxlbWVudGVkRXJyb3IgfSBmcm9tICcuLi8uLi8uLi9lcnJvcnMnO1xuaW1wb3J0IHsgSW5pdGlhbGl6ZXJJZGVudGlmaWVyIH0gZnJvbSAnLi4vLi4vLi4vaW5pdGlhbGl6ZXJzJztcblxuZXhwb3J0IGRlY2xhcmUgaW50ZXJmYWNlIFBvc2l0aW9uRW1iZWRkaW5nQXJncyBleHRlbmRzIExheWVyQXJncyB7XG4gIC8qKlxuICAgKiBJbnRlZ2VyLiBUaGUgbWF4aW11bSBsZW5ndGggb2YgdGhlIGR5bmFtaWMgc2VxdWVuY2UuXG4gICAqL1xuICBzZXF1ZW5jZUxlbmd0aDogbnVtYmVyO1xuXG4gIC8qKlxuICAgKiBUaGUgaW5pdGlhbGl6ZXIgdG8gdXNlIGZvciB0aGUgZW1iZWRkaW5nIHdlaWdodHMuXG4gICAqIERlZmF1bHRzIHRvIGBcImdsb3JvdFVuaWZvcm1cImAuXG4gICAqL1xuICBpbml0aWFsaXplcj86IEluaXRpYWxpemVySWRlbnRpZmllcjtcbn1cblxuZXhwb3J0IGRlY2xhcmUgaW50ZXJmYWNlIFBvc2l0aW9uRW1iZWRkaW5nT3B0aW9ucyB7XG4gIC8qKlxuICAgKiBJbnRlZ2VyLiBJbmRleCB0byBzdGFydCB0aGUgcG9zaXRpb24gZW1iZWRkaW5ncyBhdC5cbiAgICogRGVmYXVsdHMgdG8gMC5cbiAgICovXG4gIHN0YXJ0SW5kZXg/OiBudW1iZXI7XG59XG5cbi8qKlxuICogQSBsYXllciB3aGljaCBsZWFybnMgYSBwb3NpdGlvbiBlbWJlZGRpbmcgZm9yIGlucHV0IHNlcXVlbmNlcy5cbiAqXG4gKiBUaGlzIGNsYXNzIGFzc3VtZXMgdGhhdCBpbiB0aGUgaW5wdXQgdGVuc29yLCB0aGUgbGFzdCBkaW1lbnNpb24gY29ycmVzcG9uZHNcbiAqIHRvIHRoZSBmZWF0dXJlcywgYW5kIHRoZSBkaW1lbnNpb24gYmVmb3JlIHRoZSBsYXN0IGNvcnJlc3BvbmRzIHRvIHRoZVxuICogc2VxdWVuY2UuXG4gKlxuICogRXhhbXBsZXM6XG4gKlxuICogQ2FsbGVkIGRpcmVjdGx5IG9uIGlucHV0LlxuICogYGBganNcbiAqIGNvbnN0IGxheWVyID0gbmV3IFBvc2l0aW9uRW1iZWRkaW5nKHtzZXF1ZW5jZUxlbmd0aD0xMH0pO1xuICogbGF5ZXIuY2FsbCh0Zi56ZXJvcyhbOCwgMTAsIDE2XSkpO1xuICogYGBgXG4gKlxuICogQ29tYmluZSB3aXRoIGEgdG9rZW4gZW1iZWRkaW5nLlxuICogYGBganNcbiAqIGNvbnN0IHNlcUxlbmd0aCA9IDUwO1xuICogY29uc3Qgdm9jYWJTaXplID0gNTAwMDtcbiAqIGNvbnN0IGVtYmVkRGltID0gMTI4O1xuICogY29uc3QgaW5wdXRzID0gdGYuaW5wdXQoe3NoYXBlOiBbc2VxTGVuZ3RoXX0pO1xuICogY29uc3QgdG9rZW5FbWJlZGRpbmdzID0gdGYubGF5ZXJzLmVtYmVkZGluZyh7XG4gKiAgICAgaW5wdXREaW09dm9jYWJTaXplLCBvdXRwdXREaW09ZW1iZWREaW1cbiAqIH0pLmFwcGx5KGlucHV0cyk7XG4gKiBjb25zdCBwb3NpdGlvbkVtYmVkZGluZ3MgPSBuZXcgUG9zaXRpb25FbWJlZGRpbmcoe1xuICogICAgIHNlcXVlbmNlTGVuZ3RoOiBzZXFMZW5ndGhcbiAqIH0pLmFwcGx5KHRva2VuRW1iZWRkaW5ncyk7XG4gKiBjb25zdCBvdXRwdXRzID0gdGYuYWRkKHRva2VuRW1iZWRkaW5ncywgcG9zaXRpb25FbWJlZGRpbmdzKTtcbiAqIGBgYFxuICpcbiAqIFJlZmVyZW5jZTpcbiAqICAtIFtEZXZsaW4gZXQgYWwuLCAyMDE5XShodHRwczovL2FyeGl2Lm9yZy9hYnMvMTgxMC4wNDgwNSlcbiAqL1xuZXhwb3J0IGNsYXNzIFBvc2l0aW9uRW1iZWRkaW5nIGV4dGVuZHMgTGF5ZXIge1xuICAvKiogQG5vY29sbGFwc2UgKi9cbiAgc3RhdGljIHJlYWRvbmx5IGNsYXNzTmFtZSA9ICdQb3NpdGlvbkVtYmVkZGluZyc7XG5cbiAgY29uc3RydWN0b3IoYXJnczogUG9zaXRpb25FbWJlZGRpbmdBcmdzKSB7XG4gICAgc3VwZXIoYXJncyk7XG5cbiAgICB0aHJvdyBuZXcgTm90SW1wbGVtZW50ZWRFcnJvcignUG9zaXRpb25FbWJlZGRpbmcgbm90IGltcGxlbWVudGVkIHlldC4nKTtcbiAgfVxuXG4gIG92ZXJyaWRlIGdldENvbmZpZygpOiBzZXJpYWxpemF0aW9uLkNvbmZpZ0RpY3Qge1xuICAgIHRocm93IG5ldyBOb3RJbXBsZW1lbnRlZEVycm9yKCdOb3QgaW1wbGVtZW50ZWQgeWV0LicpO1xuICB9XG5cbiAgb3ZlcnJpZGUgYnVpbGQoaW5wdXRTaGFwZTogU2hhcGUgfCBTaGFwZVtdKTogdm9pZCB7XG4gICAgdGhyb3cgbmV3IE5vdEltcGxlbWVudGVkRXJyb3IoJ05vdCBpbXBsZW1lbnRlZCB5ZXQuJyk7XG4gIH1cblxuICBvdmVycmlkZSBjYWxsKFxuICAgIGlucHV0czogVGVuc29yfFRlbnNvcltdLFxuICAgIGt3YXJnczogUG9zaXRpb25FbWJlZGRpbmdPcHRpb25zPXtzdGFydEluZGV4OiAwfVxuICApOiBUZW5zb3IxRHxUZW5zb3IyRCB7XG4gICAgdGhyb3cgbmV3IE5vdEltcGxlbWVudGVkRXJyb3IoJ05vdCBpbXBsZW1lbnRlZCB5ZXQuJyk7XG4gIH1cbn1cbnNlcmlhbGl6YXRpb24ucmVnaXN0ZXJDbGFzcyhQb3NpdGlvbkVtYmVkZGluZyk7XG4iXX0=
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"position_embedding.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/position_embedding.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,sEAAsE;AACtE,OAAO,EAAU,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAGpE,OAAO,EAAE,KAAK,EAAa,MAAM,0BAA0B,CAAC;AAC5D,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,uBAAuB,CAAC;AACjH,OAAO,EAAE,mBAAmB,EAAE,MAAM,4BAA4B,CAAC;AAwBjE;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAgCG;AACH,MAAa,iBAAkB,SAAQ,KAAK;IAO1C,YAAY,IAA2B;QACrC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,IAAI,CAAC,cAAc,IAAI,IAAI,EAAE;YAC/B,MAAM,IAAI,UAAU,CAClB,uDAAuD,CAAC,CAAC;SAC5D;QACD,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,WAAW,GAAG,cAAc,CAAC,IAAI,CAAC,WAAW,IAAI,eAAe,CAAC,CAAC;IACzE,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,gBAAgB,EAAE,IAAI,CAAC,cAAc;YACrC,aAAa,EAAE,oBAAoB,CAAC,IAAI,CAAC,WAAW,CAAC;SACtD,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,KAAK,CAAC,UAAiB;QAC9B,MAAM,WAAW,GAAG,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACtD,IAAI,CAAC,kBAAkB,GAAG,IAAI,CAAC,SAAS,CACtC,YAAY,EACZ,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,CAAC,EAClC,IAAI,EACJ,IAAI,CAAC,WAAW,EAChB,IAAI,EACJ,IAAI,CACL,CAAC;QACF,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC1B,CAAC;IAEQ,IAAI,CACX,MAAuB,EACvB,MAAiC;QAEjC,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,MAAM,CAAC,UAAU,GAAG,MAAA,MAAM,CAAC,UAAU,mCAAI,CAAC,CAAC;YAC3C,MAAM,KAAK,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC;YAChD,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC9C,MAAM,cAAc,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC/C,sEAAsE;YACtE,yCAAyC;YACzC,MAAM,kBAAkB,GAAG,IAAI,CAAC,kBAAkB,CAAC,IAAI,EAAE,CAAC,KAAK,CAC7D,CAAC,MAAM,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,CAAC,cAAc,EAAE,aAAa,CAAC,CAAC,CAAC;YAC3D,OAAO,kBAAkB,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC;QAC/C,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAiB;QAC3C,OAAO,UAAU,CAAC;IACpB,CAAC;;AA1DD,kBAAkB;AACF,2BAAS,GAAG,mBAAmB,CAAC;SAFrC,iBAAiB;AA6D9B,aAAa,CAAC,aAAa,CAAC,iBAAiB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Position embedding implementation based on `tf.layers.Layer`.\n */\n\n/* Original source: keras_nlp/layers/modeling/position_embedding.py */\nimport { Tensor, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { Shape } from '../../../keras_format/common';\nimport { Layer, LayerArgs } from '../../../engine/topology';\nimport { ValueError } from '../../../errors';\nimport { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../../initializers';\nimport { getExactlyOneTensor } from '../../../utils/types_utils';\nimport { LayerVariable } from '../../../variables';\n\nexport declare interface PositionEmbeddingArgs extends LayerArgs {\n  /**\n   * Integer. The maximum length of the dynamic sequence.\n   */\n  sequenceLength: number;\n\n  /**\n   * The initializer to use for the embedding weights.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  initializer?: Initializer|InitializerIdentifier;\n}\n\nexport declare interface PositionEmbeddingOptions {\n  /**\n   * Integer. Index to start the position embeddings at.\n   * Defaults to 0.\n   */\n  startIndex?: number;\n}\n\n/**\n * A layer which learns a position embedding for input sequences.\n *\n * This class assumes that in the input tensor, the last dimension corresponds\n * to the features, and the dimension before the last corresponds to the\n * sequence.\n *\n * Examples:\n *\n * Called directly on input.\n * ```js\n * const layer = new PositionEmbedding({sequenceLength=10});\n * layer.call(tf.zeros([8, 10, 16]));\n * ```\n *\n * Combine with a token embedding.\n * ```js\n * const seqLength = 50;\n * const vocabSize = 5000;\n * const embedDim = 128;\n * const inputs = tf.input({shape: [seqLength]});\n * const tokenEmbeddings = tf.layers.embedding({\n *     inputDim=vocabSize, outputDim=embedDim\n * }).apply(inputs);\n * const positionEmbeddings = new PositionEmbedding({\n *     sequenceLength: seqLength\n * }).apply(tokenEmbeddings);\n * const outputs = tf.add(tokenEmbeddings, positionEmbeddings);\n * ```\n *\n * Reference:\n *  - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)\n */\nexport class PositionEmbedding extends Layer {\n  /** @nocollapse */\n  static readonly className = 'PositionEmbedding';\n  private sequenceLength: number;\n  private initializer: Initializer;\n  protected positionEmbeddings: LayerVariable;\n\n  constructor(args: PositionEmbeddingArgs) {\n    super(args);\n    if (args.sequenceLength == null) {\n      throw new ValueError(\n        '`sequenceLength` must be an Integer, received `null`.');\n    }\n    this.sequenceLength = args.sequenceLength;\n    this.initializer = getInitializer(args.initializer || 'glorotUniform');\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      'sequenceLength': this.sequenceLength,\n      'initializer': serializeInitializer(this.initializer),\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override build(inputShape: Shape): void {\n    const featureSize = inputShape[inputShape.length - 1];\n    this.positionEmbeddings = this.addWeight(\n      'embeddings',\n      [this.sequenceLength, featureSize],\n      null,\n      this.initializer,\n      null,\n      true\n    );\n    super.build(inputShape);\n  }\n\n  override call(\n    inputs: Tensor|Tensor[],\n    kwargs?: PositionEmbeddingOptions\n  ): Tensor {\n    return tidy(() => {\n      kwargs.startIndex = kwargs.startIndex ?? 0;\n      const shape = getExactlyOneTensor(inputs).shape;\n      const featureLength = shape[shape.length - 1];\n      const sequenceLength = shape[shape.length - 2];\n      // trim to match the length of the input sequence, which might be less\n      // than the sequence_length of the layer.\n      const positionEmbeddings = this.positionEmbeddings.read().slice(\n        [kwargs.startIndex, 0], [sequenceLength, featureLength]);\n      return positionEmbeddings.broadcastTo(shape);\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape): Shape {\n    return inputShape;\n  }\n}\nserialization.registerClass(PositionEmbedding);\n"]}

@@ -21,7 +21,11 @@ /**

*/
import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core';
import { Layer, LayerArgs } from '../../../engine/topology';
import { InitializerIdentifier } from '../../../initializers';
import { Tensor, serialization } from '@tensorflow/tfjs-core';
import { Activation } from '../../../activations';
import { Layer, LayerArgs, SymbolicTensor } from '../../../engine/topology';
import { Initializer, InitializerIdentifier } from '../../../initializers';
import { ActivationIdentifier } from '../../../keras_format/activation_config';
import { Shape } from '../../../keras_format/common';
import { Dense, Dropout } from '../../core';
import { LayerNormalization } from '../../normalization';
import { CachedMultiHeadAttention } from './cached_multihead_attention';
export declare interface TransformerDecoderArgs extends LayerArgs {

@@ -45,3 +49,3 @@ /**

*/
activation?: ActivationIdentifier;
activation?: Activation | ActivationIdentifier;
/**

@@ -56,3 +60,3 @@ * The eps value in layer normalization components.

*/
kernelInitializer?: InitializerIdentifier;
kernelInitializer?: Initializer | InitializerIdentifier;
/**

@@ -62,3 +66,3 @@ * The bias initializer for the dense and multiheaded attention layers.

*/
biasInitializer?: InitializerIdentifier;
biasInitializer?: Initializer | InitializerIdentifier;
/**

@@ -71,3 +75,3 @@ * If true, the inputs to the attention layer(s) and the intermediate dense

*/
normalizeFirst: boolean;
normalizeFirst?: boolean;
}

@@ -83,3 +87,3 @@ export declare interface TransformerDecoderOptions {

*/
encoderSequence?: Tensor;
encoderSequence?: Tensor | SymbolicTensor;
/**

@@ -89,3 +93,3 @@ * A boolean Tensor, the padding mask of decoder sequence, must be of shape

*/
decoderPaddingMask: Tensor;
decoderPaddingMask?: Tensor | SymbolicTensor;
/**

@@ -116,3 +120,3 @@ * A boolean Tensor. Customized decoder sequence mask, must be of shape

*/
selfAttentionCacheUpdateIndex?: number | Tensor;
selfAttentionCacheUpdateIndex?: number;
/**

@@ -129,3 +133,3 @@ * A dense float Tensor. The cache of key/value pairs in the cross-attention

*/
crossAttentionCacheUpdateIndex?: number | Tensor;
crossAttentionCacheUpdateIndex?: number;
/**

@@ -188,2 +192,22 @@ * If true, a causal mask (masking out future input) is applied on the decoder

static readonly className = "TransformerDecoder";
protected intermediateDim: number;
protected numHeads: number;
protected dropout: number;
protected activation: Activation;
protected layerNormEpsilon: number;
protected kernelInitializer: Initializer;
protected biasInitializer: Initializer;
protected normalizeFirst: boolean;
protected decoderSequenceShape: Shape;
protected encoderSequenceShape: Shape;
protected selfAttentionLayer: CachedMultiHeadAttention;
protected selfAttentionLayernorm: LayerNormalization;
protected selfAttentionDropout: Dropout;
protected selfCrossAttentionLayer: CachedMultiHeadAttention;
protected selfCrossAttentionLayernorm: LayerNormalization;
protected selfCrossAttentionDropout: Dropout;
protected feedforwardIntermediateDense: Dense;
protected feedforwardOutputDense: Dense;
protected feedforwardLayernorm: LayerNormalization;
protected feedforwardDropout: Dropout;
constructor(args: TransformerDecoderArgs);

@@ -196,5 +220,7 @@ /**

build(inputShape: Shape | [Shape, Shape]): void;
apply(inputs: Tensor | Tensor[], kwargs?: TransformerDecoderOptions): Tensor | Tensor[];
call(decoderSequence: Tensor, kwargs: TransformerDecoderOptions): Tensor | Tensor[];
apply(decoderSequence: Tensor | SymbolicTensor, kwargs?: TransformerDecoderOptions): Tensor | SymbolicTensor;
call(decoderSequence: Tensor, kwargs: TransformerDecoderOptions): Tensor;
/**
* Forward pass of the TransformerDecoder.
*
* @returns One of three things, depending on call arguments:

@@ -208,3 +234,3 @@ * - `[outputs, null, null]`, if `selfAttentionCache` is `null`.

*/
callAndReturnCaches(decoderSequence: Tensor, kwargs: TransformerDecoderOptions): [Tensor1D | Tensor2D, Tensor1D | Tensor2D, Tensor1D | Tensor2D];
callAndReturnCaches(decoderSequence: Tensor, kwargs: TransformerDecoderOptions): [Tensor, Tensor, Tensor];
private computeSelfAttentionMask;

@@ -211,0 +237,0 @@ getConfig(): serialization.ConfigDict;

@@ -21,5 +21,11 @@ /**

/* Original source: keras_nlp/layers/modeling/transformer_decoder.py */
import { serialization } from '@tensorflow/tfjs-core';
import { add, serialization, tidy } from '@tensorflow/tfjs-core';
import { getActivation, serializeActivation } from '../../../activations';
import { Layer, } from '../../../engine/topology';
import { NotImplementedError } from '../../../errors';
import { ValueError } from '../../../errors';
import { getInitializer, serializeInitializer } from '../../../initializers';
import { Dense, Dropout } from '../../core';
import { LayerNormalization } from '../../normalization';
import { CachedMultiHeadAttention } from './cached_multihead_attention';
import { computeCausalMask, mergePaddingAndAttentionMask } from './transformer_layer_utils';
/**

@@ -74,4 +80,13 @@ * Transformer decoder.

constructor(args) {
var _a, _b, _c, _d, _e, _f;
super(args);
throw new NotImplementedError(`Not implemented yet.`);
this.intermediateDim = args.intermediateDim;
this.numHeads = args.numHeads;
this.dropout = (_a = args.dropout) !== null && _a !== void 0 ? _a : 0;
this.activation = getActivation((_b = args.activation) !== null && _b !== void 0 ? _b : 'relu');
this.layerNormEpsilon = (_c = args.layerNormEpsilon) !== null && _c !== void 0 ? _c : 1e-05;
this.kernelInitializer =
getInitializer((_d = args.kernelInitializer) !== null && _d !== void 0 ? _d : 'glorotUniform');
this.biasInitializer = getInitializer((_e = args.biasInitializer) !== null && _e !== void 0 ? _e : 'zeros');
this.normalizeFirst = (_f = args.normalizeFirst) !== null && _f !== void 0 ? _f : false;
}

@@ -84,6 +99,59 @@ /**

build(inputShape) {
throw new NotImplementedError(`Not implemented yet.`);
if (Array.isArray(inputShape[0])) {
// `inputShape` is of type [Shape, Shape].
[this.decoderSequenceShape, this.encoderSequenceShape] =
inputShape;
}
else {
this.decoderSequenceShape = inputShape;
}
// Infer the dimension of our hidden feature size from the build shape.
const hiddenDim = this.decoderSequenceShape[this.decoderSequenceShape.length - 1];
// Attention head size is `hiddenDim` over the number of heads.
const headDim = Math.floor(hiddenDim / this.numHeads);
// Self attention layers.
this.selfAttentionLayer = new CachedMultiHeadAttention({
numHeads: this.numHeads,
keyDim: headDim,
dropout: this.dropout,
kernelInitializer: getInitializer(this.kernelInitializer.getClassName()),
biasInitializer: getInitializer(this.biasInitializer.getClassName()),
});
this.selfAttentionLayer.buildFromSignature(this.decoderSequenceShape, this.decoderSequenceShape);
this.selfAttentionLayernorm =
new LayerNormalization({ epsilon: this.layerNormEpsilon });
this.selfAttentionLayernorm.build(this.decoderSequenceShape);
this.selfAttentionDropout = new Dropout({ rate: this.dropout });
// Cross attention layers are optional.
// TODO(pforderique): Add cross attention layers.
// Feedforward layers.
this.feedforwardIntermediateDense = new Dense({
units: this.intermediateDim,
activation: this.activation.getClassName(),
kernelInitializer: getInitializer(this.kernelInitializer.getClassName()),
biasInitializer: getInitializer(this.biasInitializer.getClassName()),
});
this.feedforwardIntermediateDense.build(this.decoderSequenceShape);
this.feedforwardOutputDense = new Dense({
units: hiddenDim,
kernelInitializer: getInitializer(this.kernelInitializer.getClassName()),
biasInitializer: getInitializer(this.biasInitializer.getClassName()),
});
const intermediateShape = this.decoderSequenceShape.slice();
intermediateShape[intermediateShape.length - 1] = this.intermediateDim;
this.feedforwardOutputDense.build(intermediateShape);
this.feedforwardLayernorm =
new LayerNormalization({ epsilon: this.layerNormEpsilon });
this.feedforwardLayernorm.build(this.decoderSequenceShape);
this.feedforwardDropout = new Dropout({ rate: this.dropout });
// Create layers based on input shape.
this.built = true;
}
apply(inputs, kwargs) {
throw new NotImplementedError(`Not implemented yet.`);
apply(decoderSequence, kwargs) {
if (!this.built) {
const decoderSequenceShape = decoderSequence.shape;
const encoderSequenceShape = kwargs && kwargs.encoderSequence ? kwargs.encoderSequence.shape : null;
this.build([decoderSequenceShape, encoderSequenceShape]);
}
return super.apply(decoderSequence, kwargs);
}

@@ -94,2 +162,4 @@ call(decoderSequence, kwargs) {

/**
* Forward pass of the TransformerDecoder.
*
* @returns One of three things, depending on call arguments:

@@ -104,12 +174,111 @@ * - `[outputs, null, null]`, if `selfAttentionCache` is `null`.

callAndReturnCaches(decoderSequence, kwargs) {
throw new NotImplementedError(`Not implemented yet. Uses ${this.computeSelfAttentionMask}`);
return tidy(() => {
const hasEncoderSequence = kwargs.encoderSequence != null;
const hasCrossAttention = this.selfCrossAttentionLayer != null;
if (!hasCrossAttention && hasEncoderSequence) {
throw new ValueError('The number of call arguments to `TransformerDecoder` should ' +
'not change. Use `layer.apply(decoderSequence, {encoderSequence})` ' +
'to build a layer with cross attention, or ' +
'`layer.apply (decoderSequence)` to build a layer without. ' +
'This layer has been built without cross attention, but ' +
'you are trying to call it with encoderSequence.');
}
else if (hasCrossAttention && !hasEncoderSequence) {
throw new ValueError('The number of call arguments to `TransformerDecoder` should not ' +
'change. Use `layer.apply(decoderSequence, {encoderSequence})` ' +
'to build a layer with cross attention, or ' +
'`layer.apply(decoderSequence)` to build a layer without. ' +
'This layer has been built with cross attention, but ' +
'you did not provide encoderSequence.');
}
const hasSelfAttentionCache = kwargs.selfAttentionCache != null;
const hasCrossAttentionCache = kwargs.crossAttentionCache != null;
if (hasCrossAttention && (hasSelfAttentionCache !== hasCrossAttentionCache)) {
throw new ValueError('When calling `TransformerDecoder` with cross-attention (with both ' +
'`encoderSequence` and `decoderSequence`), `selfAttentionCache` ' +
'and `crossAttentionCache` should both be set or both be `null`. ' +
'One cannot be `null` while the other is not. Received: ' +
`selfAttentionCache=${kwargs.selfAttentionCache}, ` +
`crossAttentionCache=${kwargs.crossAttentionCache}.`);
}
const selfAttentionMask = this.computeSelfAttentionMask(decoderSequence, kwargs.decoderPaddingMask, kwargs.decoderAttentionMask, kwargs.useCausalMask, kwargs.selfAttentionCache, kwargs.selfAttentionCacheUpdateIndex);
let x = decoderSequence; // Intermediate result.
let selfAttentionCache = kwargs.selfAttentionCache;
// Self attention block.
let residual = x;
if (this.normalizeFirst) {
x = this.selfAttentionLayernorm.apply(x);
}
[x, selfAttentionCache] = this.selfAttentionLayer.callAndReturnCache(x, {
value: x,
attentionMask: selfAttentionMask,
cache: selfAttentionCache,
cacheUpdateIndex: kwargs.selfAttentionCacheUpdateIndex,
});
x = this.selfAttentionDropout.apply(x);
x = add(x, residual);
if (!this.normalizeFirst) {
x = this.selfAttentionLayernorm.apply(x);
}
// Cross attention is optional.
// TODO(pforderique): Add cross attention logic for encoder-decoder arch.
// Feedforward block.
residual = x;
if (this.normalizeFirst) {
x = this.selfAttentionLayernorm.apply(x);
}
x = this.feedforwardIntermediateDense.apply(x);
x = this.feedforwardOutputDense.apply(x);
x = this.feedforwardDropout.apply(x);
x = add(x, residual);
if (!this.normalizeFirst) {
x = this.selfAttentionLayernorm.apply(x);
}
if (selfAttentionCache != null) {
if (hasCrossAttention) {
return [x, selfAttentionCache, kwargs.crossAttentionCache];
}
else {
return [x, selfAttentionCache, null];
}
}
return [x, null, null];
});
}
computeSelfAttentionMask(decoderSequence, decoderPaddingMask, decoderAttentionMask, useCasualMask, selfAttentionCache, selfAttentionCacheUpdateIndex) {
throw new NotImplementedError(`Not implemented yet.`);
const decoderMask = mergePaddingAndAttentionMask(decoderSequence, decoderPaddingMask, decoderAttentionMask);
if (useCasualMask) {
const batchSize = decoderSequence.shape[0];
let inputLength = decoderSequence.shape[1];
const outputLength = decoderSequence.shape[1];
// We need to handle a rectangular causal mask when doing cached
// decoding. For generative inference, `decoderSequence` will
// generally be length 1, and `cache` will be the full generation length.
if (selfAttentionCache != null) {
inputLength = selfAttentionCache.shape[2];
}
const causalMask = computeCausalMask(batchSize, inputLength, outputLength, selfAttentionCacheUpdateIndex !== null && selfAttentionCacheUpdateIndex !== void 0 ? selfAttentionCacheUpdateIndex : 0);
return decoderMask != null ? decoderMask.minimum(causalMask) : causalMask;
}
return decoderMask;
}
getConfig() {
throw new NotImplementedError(`Not implemented yet.`);
const config = {
'intermediateDim': this.intermediateDim,
'numHeads': this.numHeads,
'dropout': this.dropout,
'activation': serializeActivation(this.activation),
'layerNormEpsilon': this.layerNormEpsilon,
'kernelInitializer': serializeInitializer(this.kernelInitializer),
'biasInitializer': serializeInitializer(this.biasInitializer),
'normalizeFirst': this.normalizeFirst,
'decoderSequenceShape': this.decoderSequenceShape,
'encoderSequenceShape': this.encoderSequenceShape,
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
computeOutputShape(decoderSequenceShape) {
throw new NotImplementedError(`Not implemented yet.`);
return decoderSequenceShape;
}

@@ -121,2 +290,2 @@ }

serialization.registerClass(TransformerDecoder);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"transformer_decoder.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,uEAAuE;AACvE,OAAO,EAA8B,aAAa,EAAE,MAAM,uBAAuB,CAAC;AAElF,OAAO,EAAE,KAAK,GAAc,MAAM,0BAA0B,CAAC;AAC7D,OAAO,EAAE,mBAAmB,EAAE,MAAM,iBAAiB,CAAC;AA+HtD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA8CG;AACH,MAAa,kBAAmB,SAAQ,KAAK;IAI3C,YAAY,IAA4B;QACtC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,MAAM,IAAI,mBAAmB,CAAC,sBAAsB,CAAC,CAAC;IACxD,CAAC;IAED;;;;OAIG;IACM,KAAK,CAAC,UAAgC;QAC7C,MAAM,IAAI,mBAAmB,CAAC,sBAAsB,CAAC,CAAC;IACxD,CAAC;IAEQ,KAAK,CACZ,MAAuB,EAAE,MAAkC;QAE3D,MAAM,IAAI,mBAAmB,CAAC,sBAAsB,CAAC,CAAC;IACxD,CAAC;IAEQ,IAAI,CACX,eAAuB,EAAE,MAAiC;QAE1D,OAAO,IAAI,CAAC,mBAAmB,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;IAC9D,CAAC;IAED;;;;;;;;OAQG;IACH,mBAAmB,CACjB,eAAuB,EAAE,MAAiC;QAE1D,MAAM,IAAI,mBAAmB,CAC3B,6BAA6B,IAAI,CAAC,wBAAwB,EAAE,CAAC,CAAC;IAClE,CAAC;IAEO,wBAAwB,CAC9B,eAAuB,EACvB,kBAA0B,EAC1B,oBAA4B,EAC5B,aAAsB,EACtB,kBAA0B,EAC1B,6BAA4C;QAE5C,MAAM,IAAI,mBAAmB,CAAC,sBAAsB,CAAC,CAAC;IACxD,CAAC;IAEQ,SAAS;QAChB,MAAM,IAAI,mBAAmB,CAAC,sBAAsB,CAAC,CAAC;IACxD,CAAC;IAEQ,kBAAkB,CAAC,oBAA2B;QACrD,MAAM,IAAI,mBAAmB,CAAC,sBAAsB,CAAC,CAAC;IACxD,CAAC;;AA9DD,kBAAkB;AACF,4BAAS,GAAG,oBAAoB,CAAC;SAFtC,kBAAkB;AAiE/B,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Transformer decoder block implementation based on TFJS `Layer`.\n */\n\n/* Original source: keras_nlp/layers/modeling/transformer_decoder.py */\nimport { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core';\n\nimport { Layer, LayerArgs, } from '../../../engine/topology';\nimport { NotImplementedError } from '../../../errors';\nimport { InitializerIdentifier } from '../../../initializers';\nimport { ActivationIdentifier } from '../../../keras_format/activation_config';\nimport { Shape } from '../../../keras_format/common';\n\nexport declare interface TransformerDecoderArgs extends LayerArgs {\n  /**\n   * Integer. The hidden size of feedforward network.\n   */\n  intermediateDim: number;\n\n  /**\n   * Integer. The number of heads in MultiHeadAttention.\n   */\n  numHeads: number;\n\n  /**\n   * The dropout value, shared by MultiHeadAttention and feedforward network.\n   * Defaults to `0.`.\n   */\n  dropout?: number;\n\n  /**\n   * The activation function of feedforward network.\n   * Defaults to `\"relu\"`.\n   */\n  activation?: ActivationIdentifier;\n\n  /**\n   * The eps value in layer normalization components.\n   * Defaults to `1e-5`.\n   */\n  layerNormEpsilon?: number;\n\n  /**\n   * The kernel initializer for the dense and multiheaded attention layers.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  kernelInitializer?: InitializerIdentifier;\n\n  /**\n   * The bias initializer for the dense and multiheaded attention layers.\n   * Defaults to `\"zeros\"`.\n   */\n  biasInitializer?: InitializerIdentifier;\n\n  /**\n   * If true, the inputs to the attention layer(s) and the intermediate dense\n   * layer are normalized (similar to GPT-2). If set to false, outputs of\n   * attention layer and intermediate dense layer are normalized\n   * (similar to BERT).\n   * Defaults to `false`.\n   */\n  normalizeFirst: boolean;\n}\n\nexport declare interface TransformerDecoderOptions {\n  /**\n   * decoderSequence: The decode input sequence.\n   */\n\n  /**\n   * The encoder input sequence. For decoder only models (like GPT2), this\n   * should be left `null`. Once the model is called without an encoderSequence,\n   * you cannot call it again with encoderSequence.\n   */\n  encoderSequence?: Tensor;\n\n  /**\n   * A boolean Tensor, the padding mask of decoder sequence, must be of shape\n   * `[batchSize, decoderSequenceLength]`.\n   */\n  decoderPaddingMask: Tensor;\n\n  /**\n   * A boolean Tensor. Customized decoder sequence mask, must be of shape\n   * `[batchSize, decoderSequenceLength, decoderSequenceLength]`.\n   */\n  decoderAttentionMask?: Tensor;\n\n  /**\n   * A boolean Tensor, the padding mask of encoder sequence, must be of shape\n   * `[batchSize, encoderSequenceLength]`.\n   */\n  encoderPaddingMask?: Tensor;\n\n  /**\n   * A boolean Tensor. Customized encoder sequence mask, must be of shape\n   * `[batchSize, encoderSequenceLength, encoderSequenceLength]`.\n   */\n  encoderAttentionMask?: Tensor;\n\n  /**\n   * A dense float Tensor. The cache of key/values pairs in the self-attention\n   * layer. Has shape `[batchSize, 2, maxSeqLen, numHeads, keyDims]`.\n   */\n  selfAttentionCache?: Tensor;\n\n  /**\n   * Integer or Integer Tensor. The index at which to update the\n   * `selfAttentionCache`. Usually, this is the index of the current token\n   * being processed during decoding.\n   */\n  selfAttentionCacheUpdateIndex?: number|Tensor;\n\n  /**\n   * A dense float Tensor. The cache of key/value pairs in the cross-attention\n   * layer. Has shape `[batchSize, 2, S, numHeads, keyDims]`.\n   */\n  crossAttentionCache?: Tensor;\n\n  /**\n   * Integer or Integer Tensor. The index at which to update the\n   * `crossAttentionCache`. Usually, this is either `0` (compute the entire\n   * `crossAttentionCache`), or `null` (reuse a previously computed\n   * `crossAttentionCache`).\n   */\n  crossAttentionCacheUpdateIndex?: number|Tensor;\n\n  /**\n   * If true, a causal mask (masking out future input) is applied on the decoder\n   * sequence.\n   * Defaults to `true`.\n   */\n  useCausalMask?: boolean;\n}\n\n/**\n * Transformer decoder.\n *\n * This class follows the architecture of the transformer decoder layer in the\n * paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users\n * can instantiate multiple instances of this class to stack up a decoder.\n *\n * By default, this layer will apply a causal mask to the decoder attention\n * layer. This layer will correctly compute an attention mask from an implicit\n * padding mask (for example, by passing `maskZero=true` to a\n * `tf.layers.embedding` layer). See the Masking and Padding\n * [guide](https://keras.io/guides/understanding_masking_and_padding/)\n * for more details.\n *\n * This layer can be called with either one or two inputs. The number of inputs\n * must be consistent across all calls. The options are as follows:\n *    `layer.call(decoderSequence)`: no cross-attention will be built into the\n *         decoder block. This is useful when building a \"decoder-only\"\n *         transformer such as GPT-2.\n *    `layer.call(decoderSequence, {encoderSequence})`: cross-attention will be\n *         built into the decoder block. This is useful when building an\n *         \"encoder-decoder\" transformer, such as the original transformer\n *         model described in Attention is All You Need.\n *\n * Examples:\n * ```js\n * // Create a single transformer decoder layer.\n * const decoder = new TransformerDecoder({intermediateDim: 64, numHeads: 8});\n *\n * // Create a simple model containing the decoder.\n * const decoderInput = tf.input({shape: [10, 64]});\n * const encoderInput = tf.input({shape: {[10, 64]});\n * const output = decoder.call(decoderInput, {encoderInput});\n * const model = tf.model({\n *     inputs: [decoderInput, encoderInput],\n *     outputs: output,\n * );\n *\n * // Call decoder on the inputs.\n * const decoderInputData = tf.randomUniform([2, 10, 64]);\n * const encoderInputData = tf.randomUniform([2, 10, 64]);\n * const decoderOutput = model.predict([decoderInputData, encoderInputData]);\n * ```\n *\n * References:\n *  - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)\n */\nexport class TransformerDecoder extends Layer {\n  /** @nocollapse */\n  static readonly className = 'TransformerDecoder';\n\n  constructor(args: TransformerDecoderArgs) {\n    super(args);\n    throw new NotImplementedError(`Not implemented yet.`);\n  }\n\n  /**\n   *\n   * @param inputShape decoderSequenceShape or\n   *  [decoderSequenceShape, encoderSequenceShape]\n   */\n  override build(inputShape: Shape|[Shape, Shape]): void {\n    throw new NotImplementedError(`Not implemented yet.`);\n  }\n\n  override apply(\n    inputs: Tensor|Tensor[], kwargs?: TransformerDecoderOptions\n  ): Tensor | Tensor[] {\n    throw new NotImplementedError(`Not implemented yet.`);\n  }\n\n  override call(\n    decoderSequence: Tensor, kwargs: TransformerDecoderOptions\n  ): Tensor|Tensor[] {\n    return this.callAndReturnCaches(decoderSequence, kwargs)[0];\n  }\n\n  /**\n   * @returns One of three things, depending on call arguments:\n   *   - `[outputs, null, null]`, if `selfAttentionCache` is `null`.\n   *   - `[outputs, selfAttentionCache, null]`, if `selfAttentionCache` is\n   *     set and the layer has no cross-attention.\n   *   - `[outputs, selfAttentionCache, crossAttentionCache]`, if\n   *     `selfAttentionCache` and `crossAttentionCache` are set and\n   *     the layer has cross-attention.\n   */\n  callAndReturnCaches(\n    decoderSequence: Tensor, kwargs: TransformerDecoderOptions\n  ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D, Tensor1D|Tensor2D] {\n    throw new NotImplementedError(\n      `Not implemented yet. Uses ${this.computeSelfAttentionMask}`);\n  }\n\n  private computeSelfAttentionMask(\n    decoderSequence: Tensor,\n    decoderPaddingMask: Tensor,\n    decoderAttentionMask: Tensor,\n    useCasualMask: boolean,\n    selfAttentionCache: Tensor,\n    selfAttentionCacheUpdateIndex: number|Tensor\n  ): Tensor {\n    throw new NotImplementedError(`Not implemented yet.`);\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    throw new NotImplementedError(`Not implemented yet.`);\n  }\n\n  override computeOutputShape(decoderSequenceShape: Shape): Shape {\n    throw new NotImplementedError(`Not implemented yet.`);\n  }\n}\nserialization.registerClass(TransformerDecoder);\n"]}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"transformer_decoder.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/transformer_decoder.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,uEAAuE;AACvE,OAAO,EAAU,GAAG,EAAE,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAEzE,OAAO,EAAc,aAAa,EAAE,mBAAmB,EAAE,MAAM,sBAAsB,CAAC;AACtF,OAAO,EAAE,KAAK,GAA8B,MAAM,0BAA0B,CAAC;AAC7E,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,uBAAuB,CAAC;AAGjH,OAAO,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,YAAY,CAAC;AAC5C,OAAO,EAAE,kBAAkB,EAAE,MAAM,qBAAqB,CAAC;AAEzD,OAAO,EAAE,wBAAwB,EAAE,MAAM,8BAA8B,CAAC;AACxE,OAAO,EAAE,iBAAiB,EAAE,4BAA4B,EAAE,MAAM,2BAA2B,CAAC;AA4H5F;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA8CG;AACH,MAAa,kBAAmB,SAAQ,KAAK;IA4B3C,YAAY,IAA4B;;QACtC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC,eAAe,CAAC;QAC5C,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,OAAO,GAAG,MAAA,IAAI,CAAC,OAAO,mCAAI,CAAC,CAAC;QACjC,IAAI,CAAC,UAAU,GAAG,aAAa,CAAC,MAAA,IAAI,CAAC,UAAU,mCAAI,MAAM,CAAC,CAAC;QAC3D,IAAI,CAAC,gBAAgB,GAAG,MAAA,IAAI,CAAC,gBAAgB,mCAAI,KAAK,CAAC;QACvD,IAAI,CAAC,iBAAiB;YACpB,cAAc,CAAC,MAAA,IAAI,CAAC,iBAAiB,mCAAI,eAAe,CAAC,CAAC;QAC5D,IAAI,CAAC,eAAe,GAAG,cAAc,CAAC,MAAA,IAAI,CAAC,eAAe,mCAAI,OAAO,CAAC,CAAC;QACvE,IAAI,CAAC,cAAc,GAAG,MAAA,IAAI,CAAC,cAAc,mCAAI,KAAK,CAAC;IACrD,CAAC;IAED;;;;OAIG;IACM,KAAK,CAAC,UAAgC;QAC7C,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE;YAChC,0CAA0C;YAC1C,CAAC,IAAI,CAAC,oBAAoB,EAAE,IAAI,CAAC,oBAAoB,CAAC;gBACpD,UAA4B,CAAC;SAChC;aAAM;YACL,IAAI,CAAC,oBAAoB,GAAG,UAAmB,CAAC;SACjD;QACD,uEAAuE;QACvE,MAAM,SAAS,GACb,IAAI,CAAC,oBAAoB,CAAC,IAAI,CAAC,oBAAoB,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QAClE,+DAA+D;QAC/D,MAAM,OAAO,GAAG,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAEtD,yBAAyB;QACzB,IAAI,CAAC,kBAAkB,GAAG,IAAI,wBAAwB,CAAC;YACrD,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,MAAM,EAAE,OAAO;YACf,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,iBAAiB,EAAE,cAAc,CAAC,IAAI,CAAC,iBAAiB,CAAC,YAAY,EAAE,CAAC;YACxE,eAAe,EAAE,cAAc,CAAC,IAAI,CAAC,eAAe,CAAC,YAAY,EAAE,CAAC;SACrE,CAAC,CAAC;QAEH,IAAI,CAAC,kBAAkB,CAAC,kBAAkB,CACxC,IAAI,CAAC,oBAAoB,EAAE,IAAI,CAAC,oBAAoB,CAAC,CAAC;QAExD,IAAI,CAAC,sBAAsB;YACzB,IAAI,kBAAkB,CAAC,EAAC,OAAO,EAAE,IAAI,CAAC,gBAAgB,EAAC,CAAC,CAAC;QAE3D,IAAI,CAAC,sBAAsB,CAAC,KAAK,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QAC7D,IAAI,CAAC,oBAAoB,GAAG,IAAI,OAAO,CAAC,EAAC,IAAI,EAAE,IAAI,CAAC,OAAO,EAAC,CAAC,CAAC;QAE9D,uCAAuC;QACvC,iDAAiD;QAEjD,sBAAsB;QACtB,IAAI,CAAC,4BAA4B,GAAG,IAAI,KAAK,CAAC;YAC5C,KAAK,EAAE,IAAI,CAAC,eAAe;YAC3B,UAAU,EAAE,IAAI,CAAC,UAAU,CAAC,YAAY,EAA0B;YAClE,iBAAiB,EAAE,cAAc,CAAC,IAAI,CAAC,iBAAiB,CAAC,YAAY,EAAE,CAAC;YACxE,eAAe,EAAE,cAAc,CAAC,IAAI,CAAC,eAAe,CAAC,YAAY,EAAE,CAAC;SACrE,CAAC,CAAC;QACH,IAAI,CAAC,4BAA4B,CAAC,KAAK,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QACnE,IAAI,CAAC,sBAAsB,GAAG,IAAI,KAAK,CAAC;YACtC,KAAK,EAAE,SAAS;YAChB,iBAAiB,EAAE,cAAc,CAAC,IAAI,CAAC,iBAAiB,CAAC,YAAY,EAAE,CAAC;YACxE,eAAe,EAAE,cAAc,CAAC,IAAI,CAAC,eAAe,CAAC,YAAY,EAAE,CAAC;SACrE,CAAC,CAAC;QACH,MAAM,iBAAiB,GAAG,IAAI,CAAC,oBAAoB,CAAC,KAAK,EAAE,CAAC;QAC5D,iBAAiB,CAAC,iBAAiB,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,eAAe,CAAC;QACvE,IAAI,CAAC,sBAAsB,CAAC,KAAK,CAAC,iBAAiB,CAAC,CAAC;QACrD,IAAI,CAAC,oBAAoB;YACvB,IAAI,kBAAkB,CAAC,EAAC,OAAO,EAAE,IAAI,CAAC,gBAAgB,EAAC,CAAC,CAAC;QAC3D,IAAI,CAAC,oBAAoB,CAAC,KAAK,CAAC,IAAI,CAAC,oBAAoB,CAAC,CAAC;QAC3D,IAAI,CAAC,kBAAkB,GAAG,IAAI,OAAO,CAAC,EAAC,IAAI,EAAE,IAAI,CAAC,OAAO,EAAC,CAAC,CAAC;QAC5D,sCAAsC;QACtC,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC;IACpB,CAAC;IAEQ,KAAK,CACV,eAAsC,EACtC,MAAkC;QACpC,IAAI,CAAC,IAAI,CAAC,KAAK,EAAE;YACf,MAAM,oBAAoB,GAAG,eAAe,CAAC,KAAK,CAAC;YACnD,MAAM,oBAAoB,GACxB,MAAM,IAAI,MAAM,CAAC,eAAe,CAAC,CAAC,CAAC,MAAM,CAAC,eAAe,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC;YACzE,IAAI,CAAC,KAAK,CAAC,CAAC,oBAAoB,EAAE,oBAAoB,CAAC,CAAC,CAAC;SAC1D;QACD,OAAO,KAAK,CAAC,KAAK,CAAC,eAAe,EAAE,MAAM,CAA0B,CAAC;IACvE,CAAC;IAEQ,IAAI,CACT,eAAuB,EAAE,MAAiC;QAC5D,OAAO,IAAI,CAAC,mBAAmB,CAAC,eAAe,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;IAC9D,CAAC;IAED;;;;;;;;;;OAUG;IACH,mBAAmB,CACjB,eAAuB,EAAE,MAAiC;QAE1D,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,kBAAkB,GAAG,MAAM,CAAC,eAAe,IAAI,IAAI,CAAC;YAC1D,MAAM,iBAAiB,GAAG,IAAI,CAAC,uBAAuB,IAAI,IAAI,CAAC;YAE/D,IAAI,CAAC,iBAAiB,IAAI,kBAAkB,EAAE;gBAC5C,MAAM,IAAI,UAAU,CAClB,8DAA8D;oBAC9D,oEAAoE;oBACpE,4CAA4C;oBAC5C,4DAA4D;oBAC5D,yDAAyD;oBACzD,iDAAiD,CAClD,CAAC;aACH;iBAAM,IAAI,iBAAiB,IAAI,CAAC,kBAAkB,EAAE;gBACnD,MAAM,IAAI,UAAU,CAClB,kEAAkE;oBAClE,gEAAgE;oBAChE,4CAA4C;oBAC5C,2DAA2D;oBAC3D,sDAAsD;oBACtD,sCAAsC,CACvC,CAAC;aACH;YAED,MAAM,qBAAqB,GAAG,MAAM,CAAC,kBAAkB,IAAI,IAAI,CAAC;YAChE,MAAM,sBAAsB,GAAG,MAAM,CAAC,mBAAmB,IAAI,IAAI,CAAC;YAClE,IAAI,iBAAiB,IAAI,CACvB,qBAAqB,KAAK,sBAAsB,CACjD,EAAE;gBACD,MAAM,IAAI,UAAU,CAClB,oEAAoE;oBACpE,iEAAiE;oBACjE,mEAAmE;oBACnE,yDAAyD;oBACzD,sBAAsB,MAAM,CAAC,kBAAkB,IAAI;oBACnD,uBAAuB,MAAM,CAAC,mBAAmB,GAAG,CACrD,CAAC;aACH;YAED,MAAM,iBAAiB,GAAG,IAAI,CAAC,wBAAwB,CACrD,eAAe,EACf,MAAM,CAAC,kBAA4B,EACnC,MAAM,CAAC,oBAAoB,EAC3B,MAAM,CAAC,aAAa,EACpB,MAAM,CAAC,kBAAkB,EACzB,MAAM,CAAC,6BAA6B,CACrC,CAAC;YAEF,IAAI,CAAC,GAAG,eAAe,CAAC,CAAC,uBAAuB;YAChD,IAAI,kBAAkB,GAAG,MAAM,CAAC,kBAAkB,CAAC;YAEnD,wBAAwB;YACxB,IAAI,QAAQ,GAAG,CAAC,CAAC;YACjB,IAAI,IAAI,CAAC,cAAc,EAAE;gBACvB,CAAC,GAAG,IAAI,CAAC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;aACpD;YACD,CAAC,CAAC,EAAE,kBAAkB,CAAC,GAAG,IAAI,CAAC,kBAAkB,CAAC,kBAAkB,CAClE,CAAC,EACD;gBACE,KAAK,EAAE,CAAC;gBACR,aAAa,EAAE,iBAAiB;gBAChC,KAAK,EAAE,kBAAkB;gBACzB,gBAAgB,EAAE,MAAM,CAAC,6BAA6B;aACvD,CACF,CAAC;YACF,CAAC,GAAG,IAAI,CAAC,oBAAoB,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;YACjD,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;YACrB,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE;gBACxB,CAAC,GAAG,IAAI,CAAC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;aACpD;YAED,+BAA+B;YAC/B,yEAAyE;YAEzE,qBAAqB;YACrB,QAAQ,GAAG,CAAC,CAAC;YACb,IAAI,IAAI,CAAC,cAAc,EAAE;gBACvB,CAAC,GAAG,IAAI,CAAC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;aACpD;YACD,CAAC,GAAG,IAAI,CAAC,4BAA4B,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;YACzD,CAAC,GAAG,IAAI,CAAC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;YACnD,CAAC,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;YAC/C,CAAC,GAAG,GAAG,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;YACrB,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE;gBACxB,CAAC,GAAG,IAAI,CAAC,sBAAsB,CAAC,KAAK,CAAC,CAAC,CAAW,CAAC;aACpD;YAED,IAAI,kBAAkB,IAAI,IAAI,EAAE;gBAC9B,IAAI,iBAAiB,EAAE;oBACrB,OAAO,CAAC,CAAC,EAAE,kBAAkB,EAAE,MAAM,CAAC,mBAAmB,CAAC,CAAC;iBAC5D;qBAAM;oBACL,OAAO,CAAC,CAAC,EAAE,kBAAkB,EAAE,IAAI,CAAC,CAAC;iBACtC;aACF;YACD,OAAO,CAAC,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC;QACzB,CAAC,CAAC,CAAC;IACL,CAAC;IAEO,wBAAwB,CAC9B,eAAuB,EACvB,kBAA0B,EAC1B,oBAA4B,EAC5B,aAAsB,EACtB,kBAA0B,EAC1B,6BAAqC;QAErC,MAAM,WAAW,GAAG,4BAA4B,CAC9C,eAAe,EAAE,kBAAkB,EAAE,oBAAoB,CAAC,CAAC;QAC7D,IAAG,aAAa,EAAE;YAChB,MAAM,SAAS,GAAG,eAAe,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC3C,IAAI,WAAW,GAAG,eAAe,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC3C,MAAM,YAAY,GAAG,eAAe,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC9C,gEAAgE;YAChE,6DAA6D;YAC7D,yEAAyE;YACzE,IAAG,kBAAkB,IAAI,IAAI,EAAE;gBAC7B,WAAW,GAAG,kBAAkB,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;aAC3C;YAED,MAAM,UAAU,GAAG,iBAAiB,CAClC,SAAS,EACT,WAAW,EACX,YAAY,EACZ,6BAA6B,aAA7B,6BAA6B,cAA7B,6BAA6B,GAAI,CAAC,CACnC,CAAC;YACF,OAAO,WAAW,IAAI,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;SAC3E;QACD,OAAO,WAAW,CAAC;IACrB,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,iBAAiB,EAAE,IAAI,CAAC,eAAe;YACvC,UAAU,EAAE,IAAI,CAAC,QAAQ;YACzB,SAAS,EAAE,IAAI,CAAC,OAAO;YACvB,YAAY,EAAE,mBAAmB,CAAC,IAAI,CAAC,UAAU,CAAC;YAClD,kBAAkB,EAAE,IAAI,CAAC,gBAAgB;YACzC,mBAAmB,EAAE,oBAAoB,CAAC,IAAI,CAAC,iBAAiB,CAAC;YACjE,iBAAiB,EAAE,oBAAoB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC7D,gBAAgB,EAAE,IAAI,CAAC,cAAc;YACrC,sBAAsB,EAAE,IAAI,CAAC,oBAAoB;YACjD,sBAAsB,EAAE,IAAI,CAAC,oBAAoB;SAClD,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,kBAAkB,CAAC,oBAA2B;QACrD,OAAO,oBAAoB,CAAC;IAC9B,CAAC;;AA7RD,kBAAkB;AACF,4BAAS,GAAG,oBAAoB,CAAC;SAFtC,kBAAkB;AAgS/B,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Transformer decoder block implementation based on TFJS `Layer`.\n */\n\n/* Original source: keras_nlp/layers/modeling/transformer_decoder.py */\nimport { Tensor, add, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { Activation, getActivation, serializeActivation } from '../../../activations';\nimport { Layer, LayerArgs, SymbolicTensor, } from '../../../engine/topology';\nimport { ValueError } from '../../../errors';\nimport { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../../initializers';\nimport { ActivationIdentifier } from '../../../keras_format/activation_config';\nimport { Shape } from '../../../keras_format/common';\nimport { Dense, Dropout } from '../../core';\nimport { LayerNormalization } from '../../normalization';\n\nimport { CachedMultiHeadAttention } from './cached_multihead_attention';\nimport { computeCausalMask, mergePaddingAndAttentionMask } from './transformer_layer_utils';\n\nexport declare interface TransformerDecoderArgs extends LayerArgs {\n  /**\n   * Integer. The hidden size of feedforward network.\n   */\n  intermediateDim: number;\n\n  /**\n   * Integer. The number of heads in MultiHeadAttention.\n   */\n  numHeads: number;\n\n  /**\n   * The dropout value, shared by MultiHeadAttention and feedforward network.\n   * Defaults to `0.`.\n   */\n  dropout?: number;\n\n  /**\n   * The activation function of feedforward network.\n   * Defaults to `\"relu\"`.\n   */\n  activation?: Activation|ActivationIdentifier;\n\n  /**\n   * The eps value in layer normalization components.\n   * Defaults to `1e-5`.\n   */\n  layerNormEpsilon?: number;\n\n  /**\n   * The kernel initializer for the dense and multiheaded attention layers.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  kernelInitializer?: Initializer|InitializerIdentifier;\n\n  /**\n   * The bias initializer for the dense and multiheaded attention layers.\n   * Defaults to `\"zeros\"`.\n   */\n  biasInitializer?: Initializer|InitializerIdentifier;\n\n  /**\n   * If true, the inputs to the attention layer(s) and the intermediate dense\n   * layer are normalized (similar to GPT-2). If set to false, outputs of\n   * attention layer and intermediate dense layer are normalized\n   * (similar to BERT).\n   * Defaults to `false`.\n   */\n  normalizeFirst?: boolean;\n}\n\nexport declare interface TransformerDecoderOptions {\n  /**\n   * decoderSequence: The decode input sequence.\n   */\n\n  /**\n   * The encoder input sequence. For decoder only models (like GPT2), this\n   * should be left `null`. Once the model is called without an encoderSequence,\n   * you cannot call it again with encoderSequence.\n   */\n  encoderSequence?: Tensor|SymbolicTensor;\n\n  /**\n   * A boolean Tensor, the padding mask of decoder sequence, must be of shape\n   * `[batchSize, decoderSequenceLength]`.\n   */\n  decoderPaddingMask?: Tensor|SymbolicTensor;\n\n  /**\n   * A boolean Tensor. Customized decoder sequence mask, must be of shape\n   * `[batchSize, decoderSequenceLength, decoderSequenceLength]`.\n   */\n  decoderAttentionMask?: Tensor;\n\n  /**\n   * A boolean Tensor, the padding mask of encoder sequence, must be of shape\n   * `[batchSize, encoderSequenceLength]`.\n   */\n  encoderPaddingMask?: Tensor;\n\n  /**\n   * A boolean Tensor. Customized encoder sequence mask, must be of shape\n   * `[batchSize, encoderSequenceLength, encoderSequenceLength]`.\n   */\n  encoderAttentionMask?: Tensor;\n\n  /**\n   * A dense float Tensor. The cache of key/values pairs in the self-attention\n   * layer. Has shape `[batchSize, 2, maxSeqLen, numHeads, keyDims]`.\n   */\n  selfAttentionCache?: Tensor;\n\n  /**\n   * Integer or Integer Tensor. The index at which to update the\n   * `selfAttentionCache`. Usually, this is the index of the current token\n   * being processed during decoding.\n   */\n  selfAttentionCacheUpdateIndex?: number;\n\n  /**\n   * A dense float Tensor. The cache of key/value pairs in the cross-attention\n   * layer. Has shape `[batchSize, 2, S, numHeads, keyDims]`.\n   */\n  crossAttentionCache?: Tensor;\n\n  /**\n   * Integer or Integer Tensor. The index at which to update the\n   * `crossAttentionCache`. Usually, this is either `0` (compute the entire\n   * `crossAttentionCache`), or `null` (reuse a previously computed\n   * `crossAttentionCache`).\n   */\n  crossAttentionCacheUpdateIndex?: number;\n\n  /**\n   * If true, a causal mask (masking out future input) is applied on the decoder\n   * sequence.\n   * Defaults to `true`.\n   */\n  useCausalMask?: boolean;\n}\n\n/**\n * Transformer decoder.\n *\n * This class follows the architecture of the transformer decoder layer in the\n * paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users\n * can instantiate multiple instances of this class to stack up a decoder.\n *\n * By default, this layer will apply a causal mask to the decoder attention\n * layer. This layer will correctly compute an attention mask from an implicit\n * padding mask (for example, by passing `maskZero=true` to a\n * `tf.layers.embedding` layer). See the Masking and Padding\n * [guide](https://keras.io/guides/understanding_masking_and_padding/)\n * for more details.\n *\n * This layer can be called with either one or two inputs. The number of inputs\n * must be consistent across all calls. The options are as follows:\n *    `layer.call(decoderSequence)`: no cross-attention will be built into the\n *         decoder block. This is useful when building a \"decoder-only\"\n *         transformer such as GPT-2.\n *    `layer.call(decoderSequence, {encoderSequence})`: cross-attention will be\n *         built into the decoder block. This is useful when building an\n *         \"encoder-decoder\" transformer, such as the original transformer\n *         model described in Attention is All You Need.\n *\n * Examples:\n * ```js\n * // Create a single transformer decoder layer.\n * const decoder = new TransformerDecoder({intermediateDim: 64, numHeads: 8});\n *\n * // Create a simple model containing the decoder.\n * const decoderInput = tf.input({shape: [10, 64]});\n * const encoderInput = tf.input({shape: {[10, 64]});\n * const output = decoder.call(decoderInput, {encoderInput});\n * const model = tf.model({\n *     inputs: [decoderInput, encoderInput],\n *     outputs: output,\n * );\n *\n * // Call decoder on the inputs.\n * const decoderInputData = tf.randomUniform([2, 10, 64]);\n * const encoderInputData = tf.randomUniform([2, 10, 64]);\n * const decoderOutput = model.predict([decoderInputData, encoderInputData]);\n * ```\n *\n * References:\n *  - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)\n */\nexport class TransformerDecoder extends Layer {\n  /** @nocollapse */\n  static readonly className = 'TransformerDecoder';\n\n  protected intermediateDim: number;\n  protected numHeads: number;\n  protected dropout: number;\n  protected activation: Activation;\n  protected layerNormEpsilon: number;\n  protected kernelInitializer: Initializer;\n  protected biasInitializer: Initializer;\n  protected normalizeFirst: boolean;\n  protected decoderSequenceShape: Shape;\n  protected encoderSequenceShape: Shape;\n\n  protected selfAttentionLayer: CachedMultiHeadAttention;\n  protected selfAttentionLayernorm: LayerNormalization;\n  protected selfAttentionDropout: Dropout;\n\n  protected selfCrossAttentionLayer: CachedMultiHeadAttention;\n  protected selfCrossAttentionLayernorm: LayerNormalization;\n  protected selfCrossAttentionDropout: Dropout;\n\n  protected feedforwardIntermediateDense: Dense;\n  protected feedforwardOutputDense: Dense;\n  protected feedforwardLayernorm: LayerNormalization;\n  protected feedforwardDropout: Dropout;\n\n  constructor(args: TransformerDecoderArgs) {\n    super(args);\n    this.intermediateDim = args.intermediateDim;\n    this.numHeads = args.numHeads;\n    this.dropout = args.dropout ?? 0;\n    this.activation = getActivation(args.activation ?? 'relu');\n    this.layerNormEpsilon = args.layerNormEpsilon ?? 1e-05;\n    this.kernelInitializer =\n      getInitializer(args.kernelInitializer ?? 'glorotUniform');\n    this.biasInitializer = getInitializer(args.biasInitializer ?? 'zeros');\n    this.normalizeFirst = args.normalizeFirst ?? false;\n  }\n\n  /**\n   *\n   * @param inputShape decoderSequenceShape or\n   *  [decoderSequenceShape, encoderSequenceShape]\n   */\n  override build(inputShape: Shape|[Shape, Shape]): void {\n    if (Array.isArray(inputShape[0])) {\n      // `inputShape` is of type [Shape, Shape].\n      [this.decoderSequenceShape, this.encoderSequenceShape] =\n        inputShape as [Shape, Shape];\n    } else {\n      this.decoderSequenceShape = inputShape as Shape;\n    }\n    // Infer the dimension of our hidden feature size from the build shape.\n    const hiddenDim =\n      this.decoderSequenceShape[this.decoderSequenceShape.length - 1];\n    // Attention head size is `hiddenDim` over the number of heads.\n    const headDim = Math.floor(hiddenDim / this.numHeads);\n\n    // Self attention layers.\n    this.selfAttentionLayer = new CachedMultiHeadAttention({\n      numHeads: this.numHeads,\n      keyDim: headDim,\n      dropout: this.dropout,\n      kernelInitializer: getInitializer(this.kernelInitializer.getClassName()),\n      biasInitializer: getInitializer(this.biasInitializer.getClassName()),\n    });\n\n    this.selfAttentionLayer.buildFromSignature(\n      this.decoderSequenceShape, this.decoderSequenceShape);\n\n    this.selfAttentionLayernorm =\n      new LayerNormalization({epsilon: this.layerNormEpsilon});\n\n    this.selfAttentionLayernorm.build(this.decoderSequenceShape);\n    this.selfAttentionDropout = new Dropout({rate: this.dropout});\n\n    // Cross attention layers are optional.\n    // TODO(pforderique): Add cross attention layers.\n\n    // Feedforward layers.\n    this.feedforwardIntermediateDense = new Dense({\n      units: this.intermediateDim,\n      activation: this.activation.getClassName() as ActivationIdentifier,\n      kernelInitializer: getInitializer(this.kernelInitializer.getClassName()),\n      biasInitializer: getInitializer(this.biasInitializer.getClassName()),\n    });\n    this.feedforwardIntermediateDense.build(this.decoderSequenceShape);\n    this.feedforwardOutputDense = new Dense({\n      units: hiddenDim,\n      kernelInitializer: getInitializer(this.kernelInitializer.getClassName()),\n      biasInitializer: getInitializer(this.biasInitializer.getClassName()),\n    });\n    const intermediateShape = this.decoderSequenceShape.slice();\n    intermediateShape[intermediateShape.length - 1] = this.intermediateDim;\n    this.feedforwardOutputDense.build(intermediateShape);\n    this.feedforwardLayernorm =\n      new LayerNormalization({epsilon: this.layerNormEpsilon});\n    this.feedforwardLayernorm.build(this.decoderSequenceShape);\n    this.feedforwardDropout = new Dropout({rate: this.dropout});\n    // Create layers based on input shape.\n    this.built = true;\n  }\n\n  override apply(\n      decoderSequence: Tensor|SymbolicTensor,\n      kwargs?: TransformerDecoderOptions): Tensor|SymbolicTensor {\n    if (!this.built) {\n      const decoderSequenceShape = decoderSequence.shape;\n      const encoderSequenceShape =\n        kwargs && kwargs.encoderSequence ? kwargs.encoderSequence.shape : null;\n      this.build([decoderSequenceShape, encoderSequenceShape]);\n    }\n    return super.apply(decoderSequence, kwargs) as Tensor|SymbolicTensor;\n  }\n\n  override call(\n      decoderSequence: Tensor, kwargs: TransformerDecoderOptions): Tensor {\n    return this.callAndReturnCaches(decoderSequence, kwargs)[0];\n  }\n\n  /**\n   * Forward pass of the TransformerDecoder.\n   *\n   * @returns One of three things, depending on call arguments:\n   *   - `[outputs, null, null]`, if `selfAttentionCache` is `null`.\n   *   - `[outputs, selfAttentionCache, null]`, if `selfAttentionCache` is\n   *     set and the layer has no cross-attention.\n   *   - `[outputs, selfAttentionCache, crossAttentionCache]`, if\n   *     `selfAttentionCache` and `crossAttentionCache` are set and\n   *     the layer has cross-attention.\n   */\n  callAndReturnCaches(\n    decoderSequence: Tensor, kwargs: TransformerDecoderOptions\n  ): [Tensor, Tensor, Tensor] {\n    return tidy(() => {\n      const hasEncoderSequence = kwargs.encoderSequence != null;\n      const hasCrossAttention = this.selfCrossAttentionLayer != null;\n\n      if (!hasCrossAttention && hasEncoderSequence) {\n        throw new ValueError(\n          'The number of call arguments to `TransformerDecoder` should ' +\n          'not change. Use `layer.apply(decoderSequence, {encoderSequence})` ' +\n          'to build a layer with cross attention, or ' +\n          '`layer.apply (decoderSequence)` to build a layer without. ' +\n          'This layer has been built without cross attention, but ' +\n          'you are trying to call it with encoderSequence.'\n        );\n      } else if (hasCrossAttention && !hasEncoderSequence) {\n        throw new ValueError(\n          'The number of call arguments to `TransformerDecoder` should not ' +\n          'change. Use `layer.apply(decoderSequence, {encoderSequence})` ' +\n          'to build a layer with cross attention, or ' +\n          '`layer.apply(decoderSequence)` to build a layer without. ' +\n          'This layer has been built with cross attention, but ' +\n          'you did not provide encoderSequence.'\n        );\n      }\n\n      const hasSelfAttentionCache = kwargs.selfAttentionCache != null;\n      const hasCrossAttentionCache = kwargs.crossAttentionCache != null;\n      if (hasCrossAttention && (\n        hasSelfAttentionCache !== hasCrossAttentionCache\n      )) {\n        throw new ValueError(\n          'When calling `TransformerDecoder` with cross-attention (with both ' +\n          '`encoderSequence` and `decoderSequence`), `selfAttentionCache` ' +\n          'and `crossAttentionCache` should both be set or both be `null`.  ' +\n          'One cannot be `null` while the other is not. Received: ' +\n          `selfAttentionCache=${kwargs.selfAttentionCache}, ` +\n          `crossAttentionCache=${kwargs.crossAttentionCache}.`\n        );\n      }\n\n      const selfAttentionMask = this.computeSelfAttentionMask(\n        decoderSequence,\n        kwargs.decoderPaddingMask as Tensor,\n        kwargs.decoderAttentionMask,\n        kwargs.useCausalMask,\n        kwargs.selfAttentionCache,\n        kwargs.selfAttentionCacheUpdateIndex,\n      );\n\n      let x = decoderSequence; // Intermediate result.\n      let selfAttentionCache = kwargs.selfAttentionCache;\n\n      // Self attention block.\n      let residual = x;\n      if (this.normalizeFirst) {\n        x = this.selfAttentionLayernorm.apply(x) as Tensor;\n      }\n      [x, selfAttentionCache] = this.selfAttentionLayer.callAndReturnCache(\n        x,\n        {\n          value: x,\n          attentionMask: selfAttentionMask,\n          cache: selfAttentionCache,\n          cacheUpdateIndex: kwargs.selfAttentionCacheUpdateIndex,\n        }\n      );\n      x = this.selfAttentionDropout.apply(x) as Tensor;\n      x = add(x, residual);\n      if (!this.normalizeFirst) {\n        x = this.selfAttentionLayernorm.apply(x) as Tensor;\n      }\n\n      // Cross attention is optional.\n      // TODO(pforderique): Add cross attention logic for encoder-decoder arch.\n\n      // Feedforward block.\n      residual = x;\n      if (this.normalizeFirst) {\n        x = this.selfAttentionLayernorm.apply(x) as Tensor;\n      }\n      x = this.feedforwardIntermediateDense.apply(x) as Tensor;\n      x = this.feedforwardOutputDense.apply(x) as Tensor;\n      x = this.feedforwardDropout.apply(x) as Tensor;\n      x = add(x, residual);\n      if (!this.normalizeFirst) {\n        x = this.selfAttentionLayernorm.apply(x) as Tensor;\n      }\n\n      if (selfAttentionCache != null) {\n        if (hasCrossAttention) {\n          return [x, selfAttentionCache, kwargs.crossAttentionCache];\n        } else {\n          return [x, selfAttentionCache, null];\n        }\n      }\n      return [x, null, null];\n    });\n  }\n\n  private computeSelfAttentionMask(\n    decoderSequence: Tensor,\n    decoderPaddingMask: Tensor,\n    decoderAttentionMask: Tensor,\n    useCasualMask: boolean,\n    selfAttentionCache: Tensor,\n    selfAttentionCacheUpdateIndex: number\n  ): Tensor {\n    const decoderMask = mergePaddingAndAttentionMask(\n      decoderSequence, decoderPaddingMask, decoderAttentionMask);\n    if(useCasualMask) {\n      const batchSize = decoderSequence.shape[0];\n      let inputLength = decoderSequence.shape[1];\n      const outputLength = decoderSequence.shape[1];\n      // We need to handle a rectangular causal mask when doing cached\n      // decoding. For generative inference, `decoderSequence` will\n      // generally be length 1, and `cache` will be the full generation length.\n      if(selfAttentionCache != null) {\n        inputLength = selfAttentionCache.shape[2];\n      }\n\n      const causalMask = computeCausalMask(\n        batchSize,\n        inputLength,\n        outputLength,\n        selfAttentionCacheUpdateIndex ?? 0\n      );\n      return decoderMask != null ? decoderMask.minimum(causalMask) : causalMask;\n    }\n    return decoderMask;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      'intermediateDim': this.intermediateDim,\n      'numHeads': this.numHeads,\n      'dropout': this.dropout,\n      'activation': serializeActivation(this.activation),\n      'layerNormEpsilon': this.layerNormEpsilon,\n      'kernelInitializer': serializeInitializer(this.kernelInitializer),\n      'biasInitializer': serializeInitializer(this.biasInitializer),\n      'normalizeFirst': this.normalizeFirst,\n      'decoderSequenceShape': this.decoderSequenceShape,\n      'encoderSequenceShape': this.encoderSequenceShape,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override computeOutputShape(decoderSequenceShape: Shape): Shape {\n    return decoderSequenceShape;\n  }\n}\nserialization.registerClass(TransformerDecoder);\n"]}

@@ -24,3 +24,6 @@ /**

import { LayersModel } from '../../../engine/training';
import { Embedding } from '../../embeddings';
export declare class Backbone extends LayersModel {
/** @nocollapse */
static className: string;
constructor(args: ContainerArgs);

@@ -30,5 +33,5 @@ /**

*/
get tokenEmbedding(): void;
get tokenEmbedding(): Embedding;
getConfig(): serialization.ConfigDict;
static fromConfig<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>, config: serialization.ConfigDict): T;
}

@@ -24,3 +24,3 @@ /**

import { NotImplementedError } from '../../../errors';
export class Backbone extends LayersModel {
class Backbone extends LayersModel {
constructor(args) {

@@ -45,3 +45,6 @@ super(args);

}
/** @nocollapse */
Backbone.className = 'Backbone';
export { Backbone };
serialization.registerClass(Backbone);
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmFja2JvbmUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMvbGF5ZXJzL25scC9tb2RlbHMvYmFja2JvbmUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUg7O0dBRUc7QUFFSCxtREFBbUQ7QUFDbkQsT0FBTyxFQUFFLGFBQWEsRUFBRSxNQUFNLHVCQUF1QixDQUFDO0FBR3RELE9BQU8sRUFBRSxXQUFXLEVBQUUsTUFBTSwwQkFBMEIsQ0FBQztBQUN2RCxPQUFPLEVBQUUsbUJBQW1CLEVBQUUsTUFBTSxpQkFBaUIsQ0FBQztBQUV0RCxNQUFNLE9BQU8sUUFBUyxTQUFRLFdBQVc7SUFFdkMsWUFBWSxJQUFtQjtRQUM3QixLQUFLLENBQUMsSUFBSSxDQUFDLENBQUM7SUFDZCxDQUFDO0lBRUQ7O09BRUc7SUFDSCxJQUFJLGNBQWM7UUFDaEIsTUFBTSxJQUFJLG1CQUFtQixFQUFFLENBQUM7SUFDbEMsQ0FBQztJQUVRLFNBQVM7UUFDaEIsT0FBTztZQUNMLElBQUksRUFBRSxJQUFJLENBQUMsSUFBSTtZQUNmLFNBQVMsRUFBRSxJQUFJLENBQUMsU0FBUztTQUMxQixDQUFDO0lBQ0osQ0FBQztJQUVELE1BQU0sQ0FBVSxVQUFVLENBQ3hCLEdBQTZDLEVBQzdDLE1BQWdDO1FBRWhDLE9BQU8sSUFBSSxHQUFHLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDekIsQ0FBQztDQUNGO0FBQ0QsYUFBYSxDQUFDLGFBQWEsQ0FBQyxRQUFRLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIzIEdvb2dsZSBMTEMuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuLyoqXG4gKiAgQmFzZSBjbGFzcyBmb3IgQmFja2JvbmUgbW9kZWxzLlxuICovXG5cbi8qIE9yaWdpbmFsIHNvdXJjZToga2VyYXNfbmxwL21vZGVscy9iYWNrYm9uZS5weSAqL1xuaW1wb3J0IHsgc2VyaWFsaXphdGlvbiB9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7IENvbnRhaW5lckFyZ3MgfSBmcm9tICcuLi8uLi8uLi9lbmdpbmUvY29udGFpbmVyJztcbmltcG9ydCB7IExheWVyc01vZGVsIH0gZnJvbSAnLi4vLi4vLi4vZW5naW5lL3RyYWluaW5nJztcbmltcG9ydCB7IE5vdEltcGxlbWVudGVkRXJyb3IgfSBmcm9tICcuLi8uLi8uLi9lcnJvcnMnO1xuXG5leHBvcnQgY2xhc3MgQmFja2JvbmUgZXh0ZW5kcyBMYXllcnNNb2RlbCB7XG5cbiAgY29uc3RydWN0b3IoYXJnczogQ29udGFpbmVyQXJncykge1xuICAgIHN1cGVyKGFyZ3MpO1xuICB9XG5cbiAgLyoqXG4gICAqIEEgYHRmLmxheWVycy5lbWJlZGRpbmdgIGluc3RhbmNlIGZvciBlbWJlZGRpbmcgdG9rZW4gaWRzLlxuICAgKi9cbiAgZ2V0IHRva2VuRW1iZWRkaW5nKCkge1xuICAgIHRocm93IG5ldyBOb3RJbXBsZW1lbnRlZEVycm9yKCk7XG4gIH1cblxuICBvdmVycmlkZSBnZXRDb25maWcoKTogc2VyaWFsaXphdGlvbi5Db25maWdEaWN0IHtcbiAgICByZXR1cm4ge1xuICAgICAgbmFtZTogdGhpcy5uYW1lLFxuICAgICAgdHJhaW5hYmxlOiB0aGlzLnRyYWluYWJsZSxcbiAgICB9O1xuICB9XG5cbiAgc3RhdGljIG92ZXJyaWRlIGZyb21Db25maWc8VCBleHRlbmRzIHNlcmlhbGl6YXRpb24uU2VyaWFsaXphYmxlPihcbiAgICBjbHM6IHNlcmlhbGl6YXRpb24uU2VyaWFsaXphYmxlQ29uc3RydWN0b3I8VD4sXG4gICAgY29uZmlnOiBzZXJpYWxpemF0aW9uLkNvbmZpZ0RpY3QpOiBUIHtcblxuICAgIHJldHVybiBuZXcgY2xzKGNvbmZpZyk7XG4gIH1cbn1cbnNlcmlhbGl6YXRpb24ucmVnaXN0ZXJDbGFzcyhCYWNrYm9uZSk7XG4iXX0=
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmFja2JvbmUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMvbGF5ZXJzL25scC9tb2RlbHMvYmFja2JvbmUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUg7O0dBRUc7QUFFSCxtREFBbUQ7QUFDbkQsT0FBTyxFQUFFLGFBQWEsRUFBRSxNQUFNLHVCQUF1QixDQUFDO0FBR3RELE9BQU8sRUFBRSxXQUFXLEVBQUUsTUFBTSwwQkFBMEIsQ0FBQztBQUN2RCxPQUFPLEVBQUUsbUJBQW1CLEVBQUUsTUFBTSxpQkFBaUIsQ0FBQztBQUd0RCxNQUFhLFFBQVMsU0FBUSxXQUFXO0lBSXZDLFlBQVksSUFBbUI7UUFDN0IsS0FBSyxDQUFDLElBQUksQ0FBQyxDQUFDO0lBQ2QsQ0FBQztJQUVEOztPQUVHO0lBQ0gsSUFBSSxjQUFjO1FBQ2hCLE1BQU0sSUFBSSxtQkFBbUIsRUFBRSxDQUFDO0lBQ2xDLENBQUM7SUFFUSxTQUFTO1FBQ2hCLE9BQU87WUFDTCxJQUFJLEVBQUUsSUFBSSxDQUFDLElBQUk7WUFDZixTQUFTLEVBQUUsSUFBSSxDQUFDLFNBQVM7U0FDMUIsQ0FBQztJQUNKLENBQUM7SUFFRCxNQUFNLENBQVUsVUFBVSxDQUN4QixHQUE2QyxFQUM3QyxNQUFnQztRQUVoQyxPQUFPLElBQUksR0FBRyxDQUFDLE1BQU0sQ0FBQyxDQUFDO0lBQ3pCLENBQUM7O0FBMUJELGtCQUFrQjtBQUNGLGtCQUFTLEdBQUcsVUFBVSxDQUFDO1NBRjVCLFFBQVE7QUE2QnJCLGFBQWEsQ0FBQyxhQUFhLENBQUMsUUFBUSxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbi8qKlxuICogIEJhc2UgY2xhc3MgZm9yIEJhY2tib25lIG1vZGVscy5cbiAqL1xuXG4vKiBPcmlnaW5hbCBzb3VyY2U6IGtlcmFzX25scC9tb2RlbHMvYmFja2JvbmUucHkgKi9cbmltcG9ydCB7IHNlcmlhbGl6YXRpb24gfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQgeyBDb250YWluZXJBcmdzIH0gZnJvbSAnLi4vLi4vLi4vZW5naW5lL2NvbnRhaW5lcic7XG5pbXBvcnQgeyBMYXllcnNNb2RlbCB9IGZyb20gJy4uLy4uLy4uL2VuZ2luZS90cmFpbmluZyc7XG5pbXBvcnQgeyBOb3RJbXBsZW1lbnRlZEVycm9yIH0gZnJvbSAnLi4vLi4vLi4vZXJyb3JzJztcbmltcG9ydCB7IEVtYmVkZGluZyB9IGZyb20gJy4uLy4uL2VtYmVkZGluZ3MnO1xuXG5leHBvcnQgY2xhc3MgQmFja2JvbmUgZXh0ZW5kcyBMYXllcnNNb2RlbCB7XG4gIC8qKiBAbm9jb2xsYXBzZSAqL1xuICBzdGF0aWMgb3ZlcnJpZGUgY2xhc3NOYW1lID0gJ0JhY2tib25lJztcblxuICBjb25zdHJ1Y3RvcihhcmdzOiBDb250YWluZXJBcmdzKSB7XG4gICAgc3VwZXIoYXJncyk7XG4gIH1cblxuICAvKipcbiAgICogQSBgdGYubGF5ZXJzLmVtYmVkZGluZ2AgaW5zdGFuY2UgZm9yIGVtYmVkZGluZyB0b2tlbiBpZHMuXG4gICAqL1xuICBnZXQgdG9rZW5FbWJlZGRpbmcoKTogRW1iZWRkaW5nIHtcbiAgICB0aHJvdyBuZXcgTm90SW1wbGVtZW50ZWRFcnJvcigpO1xuICB9XG5cbiAgb3ZlcnJpZGUgZ2V0Q29uZmlnKCk6IHNlcmlhbGl6YXRpb24uQ29uZmlnRGljdCB7XG4gICAgcmV0dXJuIHtcbiAgICAgIG5hbWU6IHRoaXMubmFtZSxcbiAgICAgIHRyYWluYWJsZTogdGhpcy50cmFpbmFibGUsXG4gICAgfTtcbiAgfVxuXG4gIHN0YXRpYyBvdmVycmlkZSBmcm9tQ29uZmlnPFQgZXh0ZW5kcyBzZXJpYWxpemF0aW9uLlNlcmlhbGl6YWJsZT4oXG4gICAgY2xzOiBzZXJpYWxpemF0aW9uLlNlcmlhbGl6YWJsZUNvbnN0cnVjdG9yPFQ+LFxuICAgIGNvbmZpZzogc2VyaWFsaXphdGlvbi5Db25maWdEaWN0KTogVCB7XG5cbiAgICByZXR1cm4gbmV3IGNscyhjb25maWcpO1xuICB9XG59XG5zZXJpYWxpemF0aW9uLnJlZ2lzdGVyQ2xhc3MoQmFja2JvbmUpO1xuIl19

@@ -21,6 +21,7 @@ /**

*/
import { Tensor, Tensor2D, serialization } from '@tensorflow/tfjs-core';
import { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';
import { LayerArgs } from '../../../../engine/topology';
import { Preprocessor } from '../preprocessor';
import { GPT2Tokenizer } from './gpt2_tokenizer';
import { StartEndPacker } from '../../preprocessing/start_end_packer';
export declare interface GPT2PreprocessorArgs extends LayerArgs {

@@ -63,6 +64,3 @@ /**

}
export declare interface PreprocessorOutputs {
tokenIds: Tensor2D;
paddingMask: Tensor2D;
}
export declare function packXYSampleWeight(x: NamedTensorMap, y?: Tensor, sampleWeight?: Tensor): NamedTensorMap | [NamedTensorMap, Tensor] | [NamedTensorMap, Tensor, Tensor];
/**

@@ -105,6 +103,8 @@ * GPT2 preprocessing layer which tokenizes and packs inputs.

export declare class GPT2Preprocessor extends Preprocessor {
private readonly sequenceLength;
private readonly addStartToken;
private readonly addEndToken;
private readonly packer;
/** @nocollapse */
static className: string;
protected readonly sequenceLength: number;
protected readonly addStartToken: boolean;
protected readonly addEndToken: boolean;
protected readonly packer: StartEndPacker;
constructor(args: GPT2PreprocessorArgs);

@@ -118,4 +118,4 @@ getConfig(): serialization.ConfigDict;

*/
callAndPackArgs(inputs: Tensor | Tensor[], kwargs: GPT2PreprocessorOptions): PreprocessorOutputs | [PreprocessorOutputs, Tensor] | [PreprocessorOutputs, Tensor, Tensor];
callAndPackArgs(inputs: Tensor | Tensor[], kwargs: GPT2PreprocessorOptions): NamedTensorMap | [NamedTensorMap, Tensor] | [NamedTensorMap, Tensor, Tensor];
static tokenizerCls<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>): typeof GPT2Tokenizer;
}

@@ -26,3 +26,3 @@ /**

import { ValueError } from '../../../../errors';
function packXYSampleWeight(x, y, sampleWeight) {
export function packXYSampleWeight(x, y, sampleWeight) {
if (y === undefined) {

@@ -74,3 +74,3 @@ return x;

*/
export class GPT2Preprocessor extends Preprocessor {
class GPT2Preprocessor extends Preprocessor {
constructor(args) {

@@ -140,3 +140,6 @@ var _a, _b, _c;

}
/** @nocollapse */
GPT2Preprocessor.className = 'GPT2Preprocessor';
export { GPT2Preprocessor };
serialization.registerClass(GPT2Preprocessor);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"gpt2_preprocessor.js","sourceRoot":"","sources":["../../../../../../../../../tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,iEAAiE;AACjE,OAAO,EAAoB,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAG9E,OAAO,EAAE,YAAY,EAAE,MAAM,iBAAiB,CAAC;AAC/C,OAAO,EAAE,aAAa,EAAE,MAAM,kBAAkB,CAAC;AACjD,OAAO,EAAE,cAAc,EAAE,MAAM,sCAAsC,CAAC;AACtE,OAAO,EAAE,UAAU,EAAE,MAAM,oBAAoB,CAAC;AAmDhD,SAAS,kBAAkB,CACzB,CAAsB,EAAE,CAAU,EAAE,YAAqB;IAKzD,IAAI,CAAC,KAAK,SAAS,EAAE;QACnB,OAAO,CAAC,CAAC;KACV;SAAM,IAAI,YAAY,KAAK,SAAS,EAAE;QACrC,OAAO,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;KACf;SAAM;QACL,OAAO,CAAC,CAAC,EAAE,CAAC,EAAE,YAAY,CAAC,CAAC;KAC7B;AACH,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAmCG;AACH,MAAM,OAAO,gBAAiB,SAAQ,YAAY;IAMhD,YAAY,IAA0B;;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,IAAI,CAAC,cAAc,GAAG,MAAA,IAAI,CAAC,cAAc,mCAAI,IAAI,CAAC;QAClD,IAAI,CAAC,aAAa,GAAG,MAAA,IAAI,CAAC,aAAa,mCAAI,IAAI,CAAC;QAChD,IAAI,CAAC,WAAW,GAAG,MAAA,IAAI,CAAC,WAAW,mCAAI,IAAI,CAAC;QAE5C,MAAM,aAAa,GAAG,IAAI,CAAC,SAA0B,CAAC;QACtD,IAAI,CAAC,MAAM,GAAG,IAAI,cAAc,CAAC;YAC/B,UAAU,EAAE,aAAa,CAAC,YAAY;YACtC,QAAQ,EAAE,aAAa,CAAC,UAAU;YAClC,QAAQ,EAAE,aAAa,CAAC,UAAU;YAClC,cAAc,EAAE,IAAI,CAAC,cAAc;SACpC,CAAC,CAAC;IACL,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,aAAa,EAAE,IAAI,CAAC,aAAa;YACjC,WAAW,EAAE,IAAI,CAAC,WAAW;SAC9B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,IAAI,CACX,MAAuB,EAAE,MAA+B;QACxD,OAAO,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,QAAQ,CAAC;IAChE,CAAC;IAEO,wBAAwB,CAC9B,MAAuB,EACvB,MAA+B;QAE/B,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,IAAI,MAAM,YAAY,KAAK,EAAE;gBAC3B,IAAI,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;oBACvB,MAAM,IAAI,UAAU,CAClB,mDAAmD;wBACnD,6BAA6B,MAAM,CAAC,MAAM,qBAAqB;wBAC/D,gEAAgE;wBAChE,6CAA6C,CAC9C,CAAC;iBACH;gBACD,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;aACpB;YAED,MAAM,cAAc,GAAG,MAAA,MAAM,CAAC,cAAc,mCAAI,IAAI,CAAC,cAAc,CAAC;YACpE,MAAM,CAAC,QAAQ,EAAE,WAAW,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,wBAAwB,CAClE,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,MAAM,CAAC,EAC3B;gBACE,cAAc;gBACd,aAAa,EAAE,IAAI,CAAC,aAAa;gBACjC,WAAW,EAAE,IAAI,CAAC,WAAW;aAC9B,CACF,CAAC;YAEF,OAAO;gBACL,QAAQ,EAAE,QAAoB;gBAC9B,WAAW,EAAE,WAAuB;aACrC,CAAC;QACJ,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;OAGG;IACH,eAAe,CAAC,MAAuB,EAAE,MAA+B;QAItE,MAAM,CAAC,GAAG,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QACxD,OAAO,kBAAkB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,EAAE,MAAM,CAAC,YAAY,CAAC,CAAC;IAC9D,CAAC;IAED,MAAM,CAAU,YAAY,CAC1B,GAA6C;QAC7C,OAAO,aAAa,CAAC;IACvB,CAAC;CACF;AACD,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * GPT-2 preprocessor layer.\n */\n\n/* Original source: keras-nlp/models/gpt2/gpt2_preprocessor.py */\nimport { Tensor, Tensor2D, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { LayerArgs } from '../../../../engine/topology';\nimport { Preprocessor } from '../preprocessor';\nimport { GPT2Tokenizer } from './gpt2_tokenizer';\nimport { StartEndPacker } from '../../preprocessing/start_end_packer';\nimport { ValueError } from '../../../../errors';\n\nexport declare interface GPT2PreprocessorArgs extends LayerArgs {\n  /**\n   * A GPT2Tokenizer instance.\n   */\n  tokenizer: GPT2Tokenizer;\n\n  /**\n   * The length of the packed inputs.\n   * Defaults to 1024.\n   */\n  sequenceLength?: number;\n\n  /**\n   * If `true`, the preprocessor will prepend the tokenizer start token to each\n   * input sequence.\n   * Defaults to `true`.\n   */\n  addStartToken?: boolean;\n\n  /**\n   * If `true`, the preprocessor will prepend the tokenizer end token to each\n   * input sequence.\n   * Defaults to `true`.\n   */\n  addEndToken?: boolean;\n}\n\nexport declare interface GPT2PreprocessorOptions {\n  /**\n   * Any label data. Will be passed through unaltered.\n   */\n  y?: Tensor;\n\n  /**\n   * Any label weight data. Will be passed through unaltered.\n   */\n  sampleWeight?: Tensor;\n\n  /**\n   * Pass to override the configured `sequenceLength` of the layer.\n   */\n  sequenceLength?: number;\n}\n\nexport declare interface PreprocessorOutputs {\n  tokenIds: Tensor2D;\n  paddingMask: Tensor2D;\n}\n\nfunction packXYSampleWeight(\n  x: PreprocessorOutputs, y?: Tensor, sampleWeight?: Tensor):\n  PreprocessorOutputs\n  | [PreprocessorOutputs, Tensor]\n  | [PreprocessorOutputs, Tensor, Tensor] {\n\n  if (y === undefined) {\n    return x;\n  } else if (sampleWeight === undefined) {\n    return [x, y];\n  } else {\n    return [x, y, sampleWeight];\n  }\n}\n\n/**\n * GPT2 preprocessing layer which tokenizes and packs inputs.\n *\n * This preprocessing layer will do 2 things:\n *\n * - Tokenize the inputs using the `tokenizer`.\n * - Construct a dictionary with keys `\"tokenIds\"`, `\"paddingMask\"`, that can\n *     be passed directly to a `GPT2Backbone`.\n *\n * The call method of this layer accepts three arguments, `x`, `y`, and\n * `sampleWeight`. `x` can be a string or tensor representing a single\n * segment, a list of strings representing a batch of single segments,\n * or a list of tensors representing multiple segments to be packed together.\n * `y` and `sampleWeight` are both optional, can have any format, and will be\n * passed through unaltered.\n *\n * `GPT2Preprocessor` forces the input to have only one segment, as GPT2 is\n * mainly used for generation tasks. For tasks having multi-segment inputs\n * like \"glue/mnli\", please use a model designed for classification purposes\n * such as BERT or RoBERTa.\n *\n * Examples:\n *\n * Directly calling the layer on data.\n * ```js\n * const features =  ['a quick fox.', 'a fox quick.'];\n * const vocabulary =\n *    new Map([['<|endoftext|>', 0], ['a', 4], ['Ġquick', 5], ['Ġfox', 6]]);\n * const merges =\n *    ['Ġ q', 'u i', 'c k', 'ui ck', 'Ġq uick', 'Ġ f', 'o x', 'Ġf ox'];\n * const tokenizer = GPT2Tokenizer({vocabulary, merges});\n *\n * const preprocessor = GPT2Preprocessor({tokenizer});\n * preprocessor.call(tensor(['the quick brown fox jumped.']))[0].print();\n * ```\n */\nexport class GPT2Preprocessor extends Preprocessor {\n  private readonly sequenceLength: number;\n  private readonly addStartToken: boolean;\n  private readonly addEndToken: boolean;\n  private readonly packer: StartEndPacker;\n\n  constructor(args: GPT2PreprocessorArgs) {\n    super(args);\n    this.tokenizer = args.tokenizer;\n    this.sequenceLength = args.sequenceLength ?? 1024;\n    this.addStartToken = args.addStartToken ?? true;\n    this.addEndToken = args.addEndToken ?? true;\n\n    const gpt2Tokenizer = this.tokenizer as GPT2Tokenizer;\n    this.packer = new StartEndPacker({\n      startValue: gpt2Tokenizer.startTokenId,\n      endValue: gpt2Tokenizer.endTokenId,\n      padValue: gpt2Tokenizer.padTokenId,\n      sequenceLength: this.sequenceLength,\n    });\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      sequenceLength: this.sequenceLength,\n      addStartToken: this.addStartToken,\n      addEndToken: this.addEndToken,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override call(\n    inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions): Tensor|Tensor[] {\n    return this.callAndReturnPaddingMask(inputs, kwargs).tokenIds;\n  }\n\n  private callAndReturnPaddingMask(\n    inputs: Tensor|Tensor[],\n    kwargs: GPT2PreprocessorOptions\n  ): PreprocessorOutputs {\n    return tidy(() => {\n      if (inputs instanceof Array) {\n        if (inputs.length !== 1) {\n          throw new ValueError(\n            'GPT2 requires each input feature to contain only ' +\n            `one segment, but received ${inputs.length}. If you are using ` +\n            'GPT2 for a multi-segment classification task, please refer to ' +\n            'classification models like BERT or RoBERTa.'\n          );\n        }\n        inputs = inputs[0];\n      }\n\n      const sequenceLength = kwargs.sequenceLength ?? this.sequenceLength;\n      const [tokenIds, paddingMask] = this.packer.callAndReturnPaddingMask(\n        this.tokenizer.call(inputs),\n        {\n          sequenceLength,\n          addStartValue: this.addStartToken,\n          addEndValue: this.addEndToken\n        }\n      );\n\n      return {\n        tokenIds: tokenIds as Tensor2D,\n        paddingMask: paddingMask as Tensor2D\n      };\n    });\n  }\n\n  /**\n   * Calls the layer and returns extra information like the paddingMask used to\n   * pack the sequence, the label data, and the sample weights used.\n   */\n  callAndPackArgs(inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions):\n    PreprocessorOutputs\n    | [PreprocessorOutputs, Tensor]\n    | [PreprocessorOutputs, Tensor, Tensor] {\n    const x = this.callAndReturnPaddingMask(inputs, kwargs);\n    return packXYSampleWeight(x, kwargs.y, kwargs.sampleWeight);\n  }\n\n  static override tokenizerCls<T extends serialization.Serializable>(\n    cls: serialization.SerializableConstructor<T>) {\n    return GPT2Tokenizer;\n  }\n}\nserialization.registerClass(GPT2Preprocessor);\n"]}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"gpt2_preprocessor.js","sourceRoot":"","sources":["../../../../../../../../../tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,iEAAiE;AACjE,OAAO,EAAoC,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAG9F,OAAO,EAAE,YAAY,EAAE,MAAM,iBAAiB,CAAC;AAC/C,OAAO,EAAE,aAAa,EAAE,MAAM,kBAAkB,CAAC;AACjD,OAAO,EAAE,cAAc,EAAE,MAAM,sCAAsC,CAAC;AACtE,OAAO,EAAE,UAAU,EAAE,MAAM,oBAAoB,CAAC;AA8ChD,MAAM,UAAU,kBAAkB,CAChC,CAAiB,EAAE,CAAU,EAAE,YAAqB;IAKpD,IAAI,CAAC,KAAK,SAAS,EAAE;QACnB,OAAO,CAAC,CAAC;KACV;SAAM,IAAI,YAAY,KAAK,SAAS,EAAE;QACrC,OAAO,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;KACf;SAAM;QACL,OAAO,CAAC,CAAC,EAAE,CAAC,EAAE,YAAY,CAAC,CAAC;KAC7B;AACH,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAmCG;AACH,MAAa,gBAAiB,SAAQ,YAAY;IAShD,YAAY,IAA0B;;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,IAAI,CAAC,cAAc,GAAG,MAAA,IAAI,CAAC,cAAc,mCAAI,IAAI,CAAC;QAClD,IAAI,CAAC,aAAa,GAAG,MAAA,IAAI,CAAC,aAAa,mCAAI,IAAI,CAAC;QAChD,IAAI,CAAC,WAAW,GAAG,MAAA,IAAI,CAAC,WAAW,mCAAI,IAAI,CAAC;QAE5C,MAAM,aAAa,GAAG,IAAI,CAAC,SAA0B,CAAC;QACtD,IAAI,CAAC,MAAM,GAAG,IAAI,cAAc,CAAC;YAC/B,UAAU,EAAE,aAAa,CAAC,YAAY;YACtC,QAAQ,EAAE,aAAa,CAAC,UAAU;YAClC,QAAQ,EAAE,aAAa,CAAC,UAAU;YAClC,cAAc,EAAE,IAAI,CAAC,cAAc;SACpC,CAAC,CAAC;IACL,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,aAAa,EAAE,IAAI,CAAC,aAAa;YACjC,WAAW,EAAE,IAAI,CAAC,WAAW;SAC9B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,IAAI,CACX,MAAuB,EAAE,MAA+B;QACxD,OAAO,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,QAAQ,CAAC;IAChE,CAAC;IAEO,wBAAwB,CAC9B,MAAuB,EACvB,MAA+B;QAE/B,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,IAAI,MAAM,YAAY,KAAK,EAAE;gBAC3B,IAAI,MAAM,CAAC,MAAM,KAAK,CAAC,EAAE;oBACvB,MAAM,IAAI,UAAU,CAClB,mDAAmD;wBACnD,6BAA6B,MAAM,CAAC,MAAM,qBAAqB;wBAC/D,gEAAgE;wBAChE,6CAA6C,CAC9C,CAAC;iBACH;gBACD,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;aACpB;YAED,MAAM,cAAc,GAAG,MAAA,MAAM,CAAC,cAAc,mCAAI,IAAI,CAAC,cAAc,CAAC;YACpE,MAAM,CAAC,QAAQ,EAAE,WAAW,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,wBAAwB,CAClE,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,MAAM,CAAC,EAC3B;gBACE,cAAc;gBACd,aAAa,EAAE,IAAI,CAAC,aAAa;gBACjC,WAAW,EAAE,IAAI,CAAC,WAAW;aAC9B,CACF,CAAC;YAEF,OAAO;gBACL,QAAQ,EAAE,QAAoB;gBAC9B,WAAW,EAAE,WAAuB;aACrC,CAAC;QACJ,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;OAGG;IACH,eAAe,CAAC,MAAuB,EAAE,MAA+B;QAItE,MAAM,CAAC,GAAG,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;QACxD,OAAO,kBAAkB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC,EAAE,MAAM,CAAC,YAAY,CAAC,CAAC;IAC9D,CAAC;IAED,MAAM,CAAU,YAAY,CAC1B,GAA6C;QAC7C,OAAO,aAAa,CAAC;IACvB,CAAC;;AAzFD,kBAAkB;AACF,0BAAS,GAAG,kBAAkB,CAAC;SAFpC,gBAAgB;AA4F7B,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n * GPT-2 preprocessor layer.\n */\n\n/* Original source: keras-nlp/models/gpt2/gpt2_preprocessor.py */\nimport { NamedTensorMap, Tensor, Tensor2D, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { LayerArgs } from '../../../../engine/topology';\nimport { Preprocessor } from '../preprocessor';\nimport { GPT2Tokenizer } from './gpt2_tokenizer';\nimport { StartEndPacker } from '../../preprocessing/start_end_packer';\nimport { ValueError } from '../../../../errors';\n\nexport declare interface GPT2PreprocessorArgs extends LayerArgs {\n  /**\n   * A GPT2Tokenizer instance.\n   */\n  tokenizer: GPT2Tokenizer;\n\n  /**\n   * The length of the packed inputs.\n   * Defaults to 1024.\n   */\n  sequenceLength?: number;\n\n  /**\n   * If `true`, the preprocessor will prepend the tokenizer start token to each\n   * input sequence.\n   * Defaults to `true`.\n   */\n  addStartToken?: boolean;\n\n  /**\n   * If `true`, the preprocessor will prepend the tokenizer end token to each\n   * input sequence.\n   * Defaults to `true`.\n   */\n  addEndToken?: boolean;\n}\n\nexport declare interface GPT2PreprocessorOptions {\n  /**\n   * Any label data. Will be passed through unaltered.\n   */\n  y?: Tensor;\n\n  /**\n   * Any label weight data. Will be passed through unaltered.\n   */\n  sampleWeight?: Tensor;\n\n  /**\n   * Pass to override the configured `sequenceLength` of the layer.\n   */\n  sequenceLength?: number;\n}\n\nexport function packXYSampleWeight(\n  x: NamedTensorMap, y?: Tensor, sampleWeight?: Tensor):\n  NamedTensorMap\n  | [NamedTensorMap, Tensor]\n  | [NamedTensorMap, Tensor, Tensor] {\n\n  if (y === undefined) {\n    return x;\n  } else if (sampleWeight === undefined) {\n    return [x, y];\n  } else {\n    return [x, y, sampleWeight];\n  }\n}\n\n/**\n * GPT2 preprocessing layer which tokenizes and packs inputs.\n *\n * This preprocessing layer will do 2 things:\n *\n * - Tokenize the inputs using the `tokenizer`.\n * - Construct a dictionary with keys `\"tokenIds\"`, `\"paddingMask\"`, that can\n *     be passed directly to a `GPT2Backbone`.\n *\n * The call method of this layer accepts three arguments, `x`, `y`, and\n * `sampleWeight`. `x` can be a string or tensor representing a single\n * segment, a list of strings representing a batch of single segments,\n * or a list of tensors representing multiple segments to be packed together.\n * `y` and `sampleWeight` are both optional, can have any format, and will be\n * passed through unaltered.\n *\n * `GPT2Preprocessor` forces the input to have only one segment, as GPT2 is\n * mainly used for generation tasks. For tasks having multi-segment inputs\n * like \"glue/mnli\", please use a model designed for classification purposes\n * such as BERT or RoBERTa.\n *\n * Examples:\n *\n * Directly calling the layer on data.\n * ```js\n * const features =  ['a quick fox.', 'a fox quick.'];\n * const vocabulary =\n *    new Map([['<|endoftext|>', 0], ['a', 4], ['Ġquick', 5], ['Ġfox', 6]]);\n * const merges =\n *    ['Ġ q', 'u i', 'c k', 'ui ck', 'Ġq uick', 'Ġ f', 'o x', 'Ġf ox'];\n * const tokenizer = GPT2Tokenizer({vocabulary, merges});\n *\n * const preprocessor = GPT2Preprocessor({tokenizer});\n * preprocessor.call(tensor(['the quick brown fox jumped.']))[0].print();\n * ```\n */\nexport class GPT2Preprocessor extends Preprocessor {\n  /** @nocollapse */\n  static override className = 'GPT2Preprocessor';\n\n  protected readonly sequenceLength: number;\n  protected readonly addStartToken: boolean;\n  protected readonly addEndToken: boolean;\n  protected readonly packer: StartEndPacker;\n\n  constructor(args: GPT2PreprocessorArgs) {\n    super(args);\n    this.tokenizer = args.tokenizer;\n    this.sequenceLength = args.sequenceLength ?? 1024;\n    this.addStartToken = args.addStartToken ?? true;\n    this.addEndToken = args.addEndToken ?? true;\n\n    const gpt2Tokenizer = this.tokenizer as GPT2Tokenizer;\n    this.packer = new StartEndPacker({\n      startValue: gpt2Tokenizer.startTokenId,\n      endValue: gpt2Tokenizer.endTokenId,\n      padValue: gpt2Tokenizer.padTokenId,\n      sequenceLength: this.sequenceLength,\n    });\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      sequenceLength: this.sequenceLength,\n      addStartToken: this.addStartToken,\n      addEndToken: this.addEndToken,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override call(\n    inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions): Tensor|Tensor[] {\n    return this.callAndReturnPaddingMask(inputs, kwargs).tokenIds;\n  }\n\n  private callAndReturnPaddingMask(\n    inputs: Tensor|Tensor[],\n    kwargs: GPT2PreprocessorOptions\n  ): NamedTensorMap {\n    return tidy(() => {\n      if (inputs instanceof Array) {\n        if (inputs.length !== 1) {\n          throw new ValueError(\n            'GPT2 requires each input feature to contain only ' +\n            `one segment, but received ${inputs.length}. If you are using ` +\n            'GPT2 for a multi-segment classification task, please refer to ' +\n            'classification models like BERT or RoBERTa.'\n          );\n        }\n        inputs = inputs[0];\n      }\n\n      const sequenceLength = kwargs.sequenceLength ?? this.sequenceLength;\n      const [tokenIds, paddingMask] = this.packer.callAndReturnPaddingMask(\n        this.tokenizer.call(inputs),\n        {\n          sequenceLength,\n          addStartValue: this.addStartToken,\n          addEndValue: this.addEndToken\n        }\n      );\n\n      return {\n        tokenIds: tokenIds as Tensor2D,\n        paddingMask: paddingMask as Tensor2D\n      };\n    });\n  }\n\n  /**\n   * Calls the layer and returns extra information like the paddingMask used to\n   * pack the sequence, the label data, and the sample weights used.\n   */\n  callAndPackArgs(inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions):\n    NamedTensorMap\n    | [NamedTensorMap, Tensor]\n    | [NamedTensorMap, Tensor, Tensor] {\n    const x = this.callAndReturnPaddingMask(inputs, kwargs);\n    return packXYSampleWeight(x, kwargs.y, kwargs.sampleWeight);\n  }\n\n  static override tokenizerCls<T extends serialization.Serializable>(\n    cls: serialization.SerializableConstructor<T>) {\n    return GPT2Tokenizer;\n  }\n}\nserialization.registerClass(GPT2Preprocessor);\n"]}

@@ -26,3 +26,3 @@ /**

/** @nocollapse */
static readonly className = "Preprocessor";
static className: string;
private _tokenizer;

@@ -29,0 +29,0 @@ constructor(args: LayerArgs);

@@ -57,2 +57,2 @@ /**

serialization.registerClass(Preprocessor);
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoicHJlcHJvY2Vzc29yLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1sYXllcnMvc3JjL2xheWVycy9ubHAvbW9kZWxzL3ByZXByb2Nlc3Nvci50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCx1REFBdUQ7QUFDdkQsT0FBTyxFQUFFLGFBQWEsRUFBRSxNQUFNLHVCQUF1QixDQUFDO0FBRXRELE9BQU8sRUFBRSxLQUFLLEVBQWEsTUFBTSwwQkFBMEIsQ0FBQztBQUM1RCxPQUFPLEVBQUUsU0FBUyxFQUFFLE1BQU0sZUFBZSxDQUFDO0FBRTFDLE9BQU8sRUFBRSxzQkFBc0IsRUFBRSxvQkFBb0IsRUFBRSxNQUFNLDhCQUE4QixDQUFDO0FBRTVGOztHQUVHO0FBQ0gsTUFBYSxZQUFhLFNBQVEsS0FBSztJQU1yQyxZQUFZLElBQWU7UUFDekIsS0FBSyxDQUFDLElBQUksQ0FBQyxDQUFDO0lBQ2QsQ0FBQztJQUVEOztPQUVHO0lBQ0gsSUFBSSxTQUFTO1FBQ1gsT0FBTyxJQUFJLENBQUMsVUFBVSxDQUFDO0lBQ3pCLENBQUM7SUFFRCxJQUFJLFNBQVMsQ0FBQyxLQUFnQjtRQUM1QixJQUFJLENBQUMsVUFBVSxHQUFHLEtBQUssQ0FBQztJQUMxQixDQUFDO0lBRVEsU0FBUztRQUNoQixNQUFNLE1BQU0sR0FBRyxLQUFLLENBQUMsU0FBUyxFQUFFLENBQUM7UUFDakMsTUFBTSxDQUFDLFNBQVMsR0FBRyxvQkFBb0IsQ0FBQyxJQUFJLENBQUMsU0FBUyxDQUFDLENBQUM7UUFDeEQsT0FBTyxNQUFNLENBQUM7SUFDaEIsQ0FBQztJQUVELE1BQU0sQ0FBVSxVQUFVLENBQ3hCLEdBQTZDLEVBQzdDLE1BQWdDO1FBRWhDLE1BQU0sTUFBTSxHQUFXLE1BQU0sQ0FBQztRQUU5QixJQUFJLE1BQU0sQ0FBQyxTQUFTLElBQUksSUFBSSxJQUFJLENBQUMsQ0FBQyxNQUFNLENBQUMsU0FBUyxZQUFZLFNBQVMsQ0FBQyxFQUFFO1lBQ3hFLE1BQU0sbUJBQW1CLEdBQUcsTUFBTSxDQUFDLFNBQXFDLENBQUM7WUFFekUsTUFBTSxDQUFDLFNBQVMsR0FBRyxzQkFBc0IsQ0FDdkMsbUJBQW1CLEVBQ25CLGFBQWEsQ0FBQyxnQkFBZ0IsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxZQUFZLEVBQ3BELEVBQUUsRUFBRSxjQUFjLENBQUMsQ0FBQztTQUN2QjtRQUNELE9BQU8sSUFBSSxHQUFHLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDekIsQ0FBQztJQUVELE1BQU0sQ0FBQyxZQUFZLENBQ2pCLEdBQTZDLElBQUcsQ0FBQzs7QUE1Q25ELGtCQUFrQjtBQUNGLHNCQUFTLEdBQUcsY0FBYyxDQUFDO1NBRmhDLFlBQVk7QUErQ3pCLGFBQWEsQ0FBQyxhQUFhLENBQUMsWUFBWSxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbi8qIE9yaWdpbmFsIHNvdXJjZToga2VyYXMtbmxwL21vZGVscy9wcmVwcm9jZXNzb3IucHkgKi9cbmltcG9ydCB7IHNlcmlhbGl6YXRpb24gfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQgeyBMYXllciwgTGF5ZXJBcmdzIH0gZnJvbSAnLi4vLi4vLi4vZW5naW5lL3RvcG9sb2d5JztcbmltcG9ydCB7IFRva2VuaXplciB9IGZyb20gJy4uL3Rva2VuaXplcnMnO1xuaW1wb3J0IHsgS3dhcmdzIH0gZnJvbSAnLi4vLi4vLi4vdHlwZXMnO1xuaW1wb3J0IHsgZGVzZXJpYWxpemVLZXJhc09iamVjdCwgc2VyaWFsaXplS2VyYXNPYmplY3QgfSBmcm9tICcuLi8uLi8uLi91dGlscy9nZW5lcmljX3V0aWxzJztcblxuLyoqXG4gKiBCYXNlIGNsYXNzIGZvciBtb2RlbCBQcmVwcm9jZXNzb3JzLlxuICovXG5leHBvcnQgY2xhc3MgUHJlcHJvY2Vzc29yIGV4dGVuZHMgTGF5ZXIge1xuICAvKiogQG5vY29sbGFwc2UgKi9cbiAgc3RhdGljIHJlYWRvbmx5IGNsYXNzTmFtZSA9ICdQcmVwcm9jZXNzb3InO1xuXG4gIHByaXZhdGUgX3Rva2VuaXplcjogVG9rZW5pemVyO1xuXG4gIGNvbnN0cnVjdG9yKGFyZ3M6IExheWVyQXJncykge1xuICAgIHN1cGVyKGFyZ3MpO1xuICB9XG5cbiAgLyoqXG4gICAqIFRoZSB0b2tlbml6ZXIgdXNlZCB0byB0b2tlbml6ZSBzdHJpbmdzLlxuICAgKi9cbiAgZ2V0IHRva2VuaXplcigpIHtcbiAgICByZXR1cm4gdGhpcy5fdG9rZW5pemVyO1xuICB9XG5cbiAgc2V0IHRva2VuaXplcih2YWx1ZTogVG9rZW5pemVyKSB7XG4gICAgdGhpcy5fdG9rZW5pemVyID0gdmFsdWU7XG4gIH1cblxuICBvdmVycmlkZSBnZXRDb25maWcoKTogc2VyaWFsaXphdGlvbi5Db25maWdEaWN0IHtcbiAgICBjb25zdCBjb25maWcgPSBzdXBlci5nZXRDb25maWcoKTtcbiAgICBjb25maWcudG9rZW5pemVyID0gc2VyaWFsaXplS2VyYXNPYmplY3QodGhpcy50b2tlbml6ZXIpO1xuICAgIHJldHVybiBjb25maWc7XG4gIH1cblxuICBzdGF0aWMgb3ZlcnJpZGUgZnJvbUNvbmZpZzxUIGV4dGVuZHMgc2VyaWFsaXphdGlvbi5TZXJpYWxpemFibGU+KFxuICAgIGNsczogc2VyaWFsaXphdGlvbi5TZXJpYWxpemFibGVDb25zdHJ1Y3RvcjxUPixcbiAgICBjb25maWc6IHNlcmlhbGl6YXRpb24uQ29uZmlnRGljdFxuICApOiBUIHtcbiAgICBjb25zdCBrd2FyZ3M6IEt3YXJncyA9IGNvbmZpZztcblxuICAgIGlmIChjb25maWcudG9rZW5pemVyICE9IG51bGwgJiYgIShjb25maWcudG9rZW5pemVyIGluc3RhbmNlb2YgVG9rZW5pemVyKSkge1xuICAgICAgY29uc3QgdG9rZW5pemVyQ29uZmlnRGljdCA9IGNvbmZpZy50b2tlbml6ZXIgYXMgc2VyaWFsaXphdGlvbi5Db25maWdEaWN0O1xuXG4gICAgICBrd2FyZ3MudG9rZW5pemVyID0gZGVzZXJpYWxpemVLZXJhc09iamVjdChcbiAgICAgICAgdG9rZW5pemVyQ29uZmlnRGljdCxcbiAgICAgICAgc2VyaWFsaXphdGlvbi5TZXJpYWxpemF0aW9uTWFwLmdldE1hcCgpLmNsYXNzTmFtZU1hcCxcbiAgICAgICAge30sICdwcmVwcm9jZXNzb3InKTtcbiAgICB9XG4gICAgcmV0dXJuIG5ldyBjbHMoa3dhcmdzKTtcbiAgfVxuXG4gIHN0YXRpYyB0b2tlbml6ZXJDbHM8VCBleHRlbmRzIHNlcmlhbGl6YXRpb24uU2VyaWFsaXphYmxlPihcbiAgICBjbHM6IHNlcmlhbGl6YXRpb24uU2VyaWFsaXphYmxlQ29uc3RydWN0b3I8VD4pIHt9XG59XG5zZXJpYWxpemF0aW9uLnJlZ2lzdGVyQ2xhc3MoUHJlcHJvY2Vzc29yKTtcbiJdfQ==
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoicHJlcHJvY2Vzc29yLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1sYXllcnMvc3JjL2xheWVycy9ubHAvbW9kZWxzL3ByZXByb2Nlc3Nvci50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCx1REFBdUQ7QUFDdkQsT0FBTyxFQUFFLGFBQWEsRUFBRSxNQUFNLHVCQUF1QixDQUFDO0FBRXRELE9BQU8sRUFBRSxLQUFLLEVBQWEsTUFBTSwwQkFBMEIsQ0FBQztBQUM1RCxPQUFPLEVBQUUsU0FBUyxFQUFFLE1BQU0sZUFBZSxDQUFDO0FBRTFDLE9BQU8sRUFBRSxzQkFBc0IsRUFBRSxvQkFBb0IsRUFBRSxNQUFNLDhCQUE4QixDQUFDO0FBRTVGOztHQUVHO0FBQ0gsTUFBYSxZQUFhLFNBQVEsS0FBSztJQU1yQyxZQUFZLElBQWU7UUFDekIsS0FBSyxDQUFDLElBQUksQ0FBQyxDQUFDO0lBQ2QsQ0FBQztJQUVEOztPQUVHO0lBQ0gsSUFBSSxTQUFTO1FBQ1gsT0FBTyxJQUFJLENBQUMsVUFBVSxDQUFDO0lBQ3pCLENBQUM7SUFFRCxJQUFJLFNBQVMsQ0FBQyxLQUFnQjtRQUM1QixJQUFJLENBQUMsVUFBVSxHQUFHLEtBQUssQ0FBQztJQUMxQixDQUFDO0lBRVEsU0FBUztRQUNoQixNQUFNLE1BQU0sR0FBRyxLQUFLLENBQUMsU0FBUyxFQUFFLENBQUM7UUFDakMsTUFBTSxDQUFDLFNBQVMsR0FBRyxvQkFBb0IsQ0FBQyxJQUFJLENBQUMsU0FBUyxDQUFDLENBQUM7UUFDeEQsT0FBTyxNQUFNLENBQUM7SUFDaEIsQ0FBQztJQUVELE1BQU0sQ0FBVSxVQUFVLENBQ3hCLEdBQTZDLEVBQzdDLE1BQWdDO1FBRWhDLE1BQU0sTUFBTSxHQUFXLE1BQU0sQ0FBQztRQUU5QixJQUFJLE1BQU0sQ0FBQyxTQUFTLElBQUksSUFBSSxJQUFJLENBQUMsQ0FBQyxNQUFNLENBQUMsU0FBUyxZQUFZLFNBQVMsQ0FBQyxFQUFFO1lBQ3hFLE1BQU0sbUJBQW1CLEdBQUcsTUFBTSxDQUFDLFNBQXFDLENBQUM7WUFFekUsTUFBTSxDQUFDLFNBQVMsR0FBRyxzQkFBc0IsQ0FDdkMsbUJBQW1CLEVBQ25CLGFBQWEsQ0FBQyxnQkFBZ0IsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxZQUFZLEVBQ3BELEVBQUUsRUFBRSxjQUFjLENBQUMsQ0FBQztTQUN2QjtRQUNELE9BQU8sSUFBSSxHQUFHLENBQUMsTUFBTSxDQUFDLENBQUM7SUFDekIsQ0FBQztJQUVELE1BQU0sQ0FBQyxZQUFZLENBQ2pCLEdBQTZDLElBQUcsQ0FBQzs7QUE1Q25ELGtCQUFrQjtBQUNYLHNCQUFTLEdBQUcsY0FBYyxDQUFDO1NBRnZCLFlBQVk7QUErQ3pCLGFBQWEsQ0FBQyxhQUFhLENBQUMsWUFBWSxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbi8qIE9yaWdpbmFsIHNvdXJjZToga2VyYXMtbmxwL21vZGVscy9wcmVwcm9jZXNzb3IucHkgKi9cbmltcG9ydCB7IHNlcmlhbGl6YXRpb24gfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQgeyBMYXllciwgTGF5ZXJBcmdzIH0gZnJvbSAnLi4vLi4vLi4vZW5naW5lL3RvcG9sb2d5JztcbmltcG9ydCB7IFRva2VuaXplciB9IGZyb20gJy4uL3Rva2VuaXplcnMnO1xuaW1wb3J0IHsgS3dhcmdzIH0gZnJvbSAnLi4vLi4vLi4vdHlwZXMnO1xuaW1wb3J0IHsgZGVzZXJpYWxpemVLZXJhc09iamVjdCwgc2VyaWFsaXplS2VyYXNPYmplY3QgfSBmcm9tICcuLi8uLi8uLi91dGlscy9nZW5lcmljX3V0aWxzJztcblxuLyoqXG4gKiBCYXNlIGNsYXNzIGZvciBtb2RlbCBQcmVwcm9jZXNzb3JzLlxuICovXG5leHBvcnQgY2xhc3MgUHJlcHJvY2Vzc29yIGV4dGVuZHMgTGF5ZXIge1xuICAvKiogQG5vY29sbGFwc2UgKi9cbiAgc3RhdGljIGNsYXNzTmFtZSA9ICdQcmVwcm9jZXNzb3InO1xuXG4gIHByaXZhdGUgX3Rva2VuaXplcjogVG9rZW5pemVyO1xuXG4gIGNvbnN0cnVjdG9yKGFyZ3M6IExheWVyQXJncykge1xuICAgIHN1cGVyKGFyZ3MpO1xuICB9XG5cbiAgLyoqXG4gICAqIFRoZSB0b2tlbml6ZXIgdXNlZCB0byB0b2tlbml6ZSBzdHJpbmdzLlxuICAgKi9cbiAgZ2V0IHRva2VuaXplcigpIHtcbiAgICByZXR1cm4gdGhpcy5fdG9rZW5pemVyO1xuICB9XG5cbiAgc2V0IHRva2VuaXplcih2YWx1ZTogVG9rZW5pemVyKSB7XG4gICAgdGhpcy5fdG9rZW5pemVyID0gdmFsdWU7XG4gIH1cblxuICBvdmVycmlkZSBnZXRDb25maWcoKTogc2VyaWFsaXphdGlvbi5Db25maWdEaWN0IHtcbiAgICBjb25zdCBjb25maWcgPSBzdXBlci5nZXRDb25maWcoKTtcbiAgICBjb25maWcudG9rZW5pemVyID0gc2VyaWFsaXplS2VyYXNPYmplY3QodGhpcy50b2tlbml6ZXIpO1xuICAgIHJldHVybiBjb25maWc7XG4gIH1cblxuICBzdGF0aWMgb3ZlcnJpZGUgZnJvbUNvbmZpZzxUIGV4dGVuZHMgc2VyaWFsaXphdGlvbi5TZXJpYWxpemFibGU+KFxuICAgIGNsczogc2VyaWFsaXphdGlvbi5TZXJpYWxpemFibGVDb25zdHJ1Y3RvcjxUPixcbiAgICBjb25maWc6IHNlcmlhbGl6YXRpb24uQ29uZmlnRGljdFxuICApOiBUIHtcbiAgICBjb25zdCBrd2FyZ3M6IEt3YXJncyA9IGNvbmZpZztcblxuICAgIGlmIChjb25maWcudG9rZW5pemVyICE9IG51bGwgJiYgIShjb25maWcudG9rZW5pemVyIGluc3RhbmNlb2YgVG9rZW5pemVyKSkge1xuICAgICAgY29uc3QgdG9rZW5pemVyQ29uZmlnRGljdCA9IGNvbmZpZy50b2tlbml6ZXIgYXMgc2VyaWFsaXphdGlvbi5Db25maWdEaWN0O1xuXG4gICAgICBrd2FyZ3MudG9rZW5pemVyID0gZGVzZXJpYWxpemVLZXJhc09iamVjdChcbiAgICAgICAgdG9rZW5pemVyQ29uZmlnRGljdCxcbiAgICAgICAgc2VyaWFsaXphdGlvbi5TZXJpYWxpemF0aW9uTWFwLmdldE1hcCgpLmNsYXNzTmFtZU1hcCxcbiAgICAgICAge30sICdwcmVwcm9jZXNzb3InKTtcbiAgICB9XG4gICAgcmV0dXJuIG5ldyBjbHMoa3dhcmdzKTtcbiAgfVxuXG4gIHN0YXRpYyB0b2tlbml6ZXJDbHM8VCBleHRlbmRzIHNlcmlhbGl6YXRpb24uU2VyaWFsaXphYmxlPihcbiAgICBjbHM6IHNlcmlhbGl6YXRpb24uU2VyaWFsaXphYmxlQ29uc3RydWN0b3I8VD4pIHt9XG59XG5zZXJpYWxpemF0aW9uLnJlZ2lzdGVyQ2xhc3MoUHJlcHJvY2Vzc29yKTtcbiJdfQ==

@@ -21,8 +21,12 @@ /**

*/
import { Tensor, Tensor1D, Tensor2D, serialization } from '@tensorflow/tfjs-core';
import { ConstraintIdentifier } from '../../constraints';
import { Layer, LayerArgs } from '../../engine/topology';
import { InitializerIdentifier } from '../../initializers';
import { Tensor, serialization } from '@tensorflow/tfjs-core';
import { Constraint, ConstraintIdentifier } from '../../constraints';
import { Layer, LayerArgs, SymbolicTensor } from '../../engine/topology';
import { Initializer, InitializerIdentifier } from '../../initializers';
import { Shape } from '../../keras_format/common';
import { RegularizerIdentifier } from '../../regularizers';
import { Regularizer, RegularizerIdentifier } from '../../regularizers';
import { Kwargs } from '../../types';
import { Softmax } from '../advanced_activations';
import { Dropout } from '../core';
import { EinsumDense } from './einsum_dense';
export declare interface MultiHeadAttentionArgs extends LayerArgs {

@@ -62,3 +66,3 @@ /**

*/
attentionAxes: number[];
attentionAxes?: number[] | number;
/**

@@ -68,3 +72,3 @@ * Initializer for dense layer kernels.

*/
kernelInitializer?: InitializerIdentifier;
kernelInitializer?: Initializer | InitializerIdentifier;
/**

@@ -74,23 +78,23 @@ * Initializer for dense layer biases.

*/
biasInitializer?: InitializerIdentifier;
biasInitializer?: Initializer | InitializerIdentifier;
/**
* Regularizer for dense layer kernels.
*/
kernelRegularizer?: RegularizerIdentifier;
kernelRegularizer?: Regularizer | RegularizerIdentifier;
/**
* Regularizer for dense layer biases.
*/
biasRegularizer?: RegularizerIdentifier;
biasRegularizer?: Regularizer | RegularizerIdentifier;
/**
* Regularizer for dense layer activity.
*/
activityRegularizer?: RegularizerIdentifier;
activityRegularizer?: Regularizer | RegularizerIdentifier;
/**
* Constraint for dense layer kernels.
*/
kernelConstraint?: ConstraintIdentifier;
kernelConstraint?: Constraint | ConstraintIdentifier;
/**
* Constraint for dense layer kernels.
*/
biasConstraint?: ConstraintIdentifier;
biasConstraint?: Constraint | ConstraintIdentifier;
}

@@ -165,2 +169,3 @@ export declare interface MultiHeadAttentionOptions {

*
* ```js
* const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2});

@@ -173,5 +178,7 @@ * const target = tf.input({shape: [8, 16]});

* console.log(weights.shape); // [null, 2, 8, 4]
* ```
*
* Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
*
* ```js
* const layer = new MultiHeadAttention({

@@ -182,2 +189,3 @@ * numHeads: 2, keyDim: 2, attentionAxes: [2, 3]});

* console.log(outputTensor.shape); // [null, 5, 3, 4, 16]
* ```
*

@@ -195,3 +203,44 @@ * Returns:

static readonly className = "MultiHeadAttention";
protected readonly numHeads: number;
protected readonly keyDim: number;
protected readonly valueDim: number;
protected readonly dropout: number;
protected readonly useBias: boolean;
protected readonly _outputShape: Shape;
protected readonly kernelInitializer: Initializer;
protected readonly biasInitializer: Initializer;
protected readonly kernelRegularizer: Regularizer;
protected readonly biasRegularizer: Regularizer;
protected readonly kernelConstraint: Constraint;
protected readonly biasConstraint: Constraint;
protected dotProductEquation: string;
protected combineEquation: string;
protected attentionAxes: number[];
protected builtFromSignature: boolean;
protected softmax: Softmax;
protected dropoutLayer: Dropout;
protected queryShape: Shape;
protected keyShape: Shape;
protected valueShape: Shape;
protected queryDense: EinsumDense;
protected keyDense: EinsumDense;
protected valueDense: EinsumDense;
protected outputDense: EinsumDense;
constructor(args: MultiHeadAttentionArgs);
/**
* Should be used for testing purposes only.
*/
get _queryDense(): EinsumDense;
/**
* Should be used for testing purposes only.
*/
get _keyDense(): EinsumDense;
/**
* Should be used for testing purposes only.
*/
get _valueDense(): EinsumDense;
/**
* Should be used for testing purposes only.
*/
get _outputDense(): EinsumDense;
getConfig(): serialization.ConfigDict;

@@ -204,3 +253,3 @@ static fromConfig<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>, config: serialization.ConfigDict): T;

*/
private buildFromSignature;
buildFromSignature(queryShape: Shape, valueShape: Shape, keyShape?: Shape): void;
private getCommonKwargsForSublayer;

@@ -219,3 +268,3 @@ /**

*
* This function builds attributes necessary for `_compute_attention` to
* This function builds attributes necessary for `computeAttention` to
* customize attention computation to replace the default dot-product

@@ -226,4 +275,4 @@ * attention.

*/
private buildAttention;
private maskedSoftmax;
protected buildAttention(rank: number): void;
protected maskedSoftmax(attentionScores: Tensor, attentionMask?: Tensor): Tensor;
/**

@@ -236,5 +285,5 @@ * Applies Dot-product attention with query, key, value tensors.

*
* @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`.
* @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`.
* @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`.
* @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`.
* @param key Projected key `Tensor` of shape `(B, S, N, keyDim)`.
* @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`.
* @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents

@@ -249,8 +298,9 @@ * attention to certain positions. It is generally not needed if

*/
private computeAttention;
call(query: Tensor, kwargs: MultiHeadAttentionOptions): Tensor | Tensor2D;
protected computeAttention(query: Tensor, key: Tensor, value: Tensor, attentionMask?: Tensor, training?: boolean): [Tensor, Tensor];
apply(inputs: Tensor | SymbolicTensor, kwargs?: Kwargs): Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[];
call(query: Tensor, kwargs: MultiHeadAttentionOptions): Tensor;
/**
* Exactly like `call` except also returns the attention scores.
*/
callAndReturnAttentionScores(query: Tensor, kwargs: MultiHeadAttentionOptions): [Tensor1D | Tensor2D, Tensor1D | Tensor2D];
callAndReturnAttentionScores(query: Tensor, { value, key, useCausalMask, attentionMask, training }: MultiHeadAttentionOptions): [Tensor, Tensor];
/**

@@ -271,5 +321,5 @@ * Computes the attention mask.

*
* @param query Projected query `Tensor` of shape `(B, T, N, key_dim)`.
* @param key Projected key `Tensor` of shape `(B, S, N, key_dim)`.
* @param value Projected value `Tensor` of shape `(B, S, N, value_dim)`.
* @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`.
* @param key Projected key `Tensor` of shape `(B, S, N, keyDim)`.
* @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`.
* @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents

@@ -304,3 +354,3 @@ * attention to certain positions.

*/
private computeCasualMask;
private computeCausalMask;
/**

@@ -312,3 +362,3 @@ *

*/
computeOutputShape(inputShapes: [Shape, Shape] | [Shape, Shape, Shape]): Shape;
computeOutputShape(inputShapes: [Shape, Shape, Shape | null]): Shape;
}

@@ -18,3 +18,7 @@ /**

/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/utils" />
import { Tensor } from '@tensorflow/tfjs-core';
import { ModelPredictConfig, Scalar, Tensor } from '@tensorflow/tfjs-core';
import { History } from '../../base_callbacks';
import { ContainerArgs } from '../../engine/container';
import { LayersModel, ModelEvaluateArgs } from '../../engine/training';
import { ModelFitArgs } from '../../engine/training_tensors';
export declare function tensorToArr(input: Tensor): unknown[];

@@ -33,1 +37,41 @@ export declare function tensorArrTo2DArr(inputs: Tensor[]): unknown[][];

export declare function sliceUpdate(inputs: Tensor, startIndices: number[], updates: Tensor): Tensor;
/**
* A model which allows automatically applying preprocessing.
*/
export interface PipelineModelArgs extends ContainerArgs {
/**
* Defaults to true.
*/
includePreprocessing?: boolean;
}
export declare class PipelineModel extends LayersModel {
/** @nocollapse */
static className: string;
protected includePreprocessing: boolean;
constructor(args: PipelineModelArgs);
/**
* An overridable function which preprocesses features.
*/
preprocessFeatures(x: Tensor): Tensor<import("@tensorflow/tfjs-core").Rank>;
/**
* An overridable function which preprocesses labels.
*/
preprocessLabels(y: Tensor): Tensor<import("@tensorflow/tfjs-core").Rank>;
/**
* An overridable function which preprocesses entire samples.
*/
preprocessSamples(x: Tensor, y?: Tensor, sampleWeight?: Tensor): Tensor | [Tensor, Tensor] | [Tensor, Tensor, Tensor];
fit(x: Tensor | Tensor[] | {
[inputName: string]: Tensor;
}, y: Tensor | Tensor[] | {
[inputName: string]: Tensor;
}, args?: ModelFitArgs): Promise<History>;
evaluate(x: Tensor | Tensor[], y: Tensor | Tensor[], args?: ModelEvaluateArgs): Scalar | Scalar[];
predict(x: Tensor | Tensor[], args?: ModelPredictConfig): Tensor | Tensor[];
trainOnBatch(x: Tensor | Tensor[] | {
[inputName: string]: Tensor;
}, y: Tensor | Tensor[] | {
[inputName: string]: Tensor;
}, sampleWeight?: Tensor): Promise<number | number[]>;
predictOnBatch(x: Tensor | Tensor[]): Tensor | Tensor[];
}

@@ -18,2 +18,4 @@ /**

import { tensorScatterUpdate, tidy } from '@tensorflow/tfjs-core';
import { LayersModel } from '../../engine/training';
import { NotImplementedError } from '../../errors';
export function tensorToArr(input) {

@@ -61,2 +63,62 @@ return Array.from(input.dataSync());

}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidXRpbHMuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMvbGF5ZXJzL25scC91dGlscy50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQVUsbUJBQW1CLEVBQUUsSUFBSSxFQUFFLE1BQU0sdUJBQXVCLENBQUM7QUFFMUUsTUFBTSxVQUFVLFdBQVcsQ0FBQyxLQUFhO0lBQ3ZDLE9BQU8sS0FBSyxDQUFDLElBQUksQ0FBQyxLQUFLLENBQUMsUUFBUSxFQUFFLENBQXlCLENBQUM7QUFDOUQsQ0FBQztBQUVELE1BQU0sVUFBVSxnQkFBZ0IsQ0FBQyxNQUFnQjtJQUMvQyxPQUFPLE1BQU0sQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLEVBQUUsQ0FBQyxXQUFXLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQztBQUNqRCxDQUFDO0FBRUQ7Ozs7Ozs7OztHQVNHO0FBQ0gsTUFBTSxVQUFVLFdBQVcsQ0FDdkIsTUFBYyxFQUFFLFlBQXNCLEVBQUUsT0FBZTtJQUN6RCxPQUFPLElBQUksQ0FBQyxHQUFHLEVBQUU7UUFDZixNQUFNLE9BQU8sR0FBZSxFQUFFLENBQUM7UUFDL0I7OztXQUdHO1FBQ0gsU0FBUyxhQUFhLENBQUMsR0FBVyxFQUFFLElBQWM7WUFDaEQsSUFBSSxJQUFJLENBQUMsTUFBTSxLQUFLLFlBQVksQ0FBQyxNQUFNLEVBQUU7Z0JBQ3ZDLE9BQU8sQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLEtBQUssRUFBRSxDQUFDLENBQUM7Z0JBQzNCLE9BQU87YUFDUjtZQUNELE1BQU0sS0FBSyxHQUFHLFlBQVksQ0FBQyxHQUFHLENBQUMsQ0FBQztZQUNoQyxNQUFNLEdBQUcsR0FBRyxLQUFLLEdBQUcsT0FBTyxDQUFDLEtBQUssQ0FBQyxHQUFHLENBQUMsQ0FBQztZQUN2QyxLQUFLLElBQUksQ0FBQyxHQUFHLEtBQUssRUFBRSxDQUFDLEdBQUcsR0FBRyxFQUFFLENBQUMsRUFBRSxFQUFFO2dCQUNoQyxJQUFJLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO2dCQUNiLGFBQWEsQ0FBQyxHQUFHLEdBQUcsQ0FBQyxFQUFFLElBQUksQ0FBQyxDQUFDO2dCQUM3QixJQUFJLENBQUMsR0FBRyxFQUFFLENBQUM7YUFDWjtRQUNILENBQUM7UUFDRCxhQUFhLENBQUMsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxDQUFDO1FBQ3JCLDZEQUE2RDtRQUM3RCxPQUFPLEdBQUcsT0FBTyxDQUFDLE9BQU8sQ0FBQyxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO1FBQzFDLE9BQU8sbUJBQW1CLENBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxPQUFPLENBQUMsQ0FBQztJQUN2RCxDQUFDLENBQUMsQ0FBQztBQUNMLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMyBHb29nbGUgTExDLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7IFRlbnNvciwgdGVuc29yU2NhdHRlclVwZGF0ZSwgdGlkeSB9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiB0ZW5zb3JUb0FycihpbnB1dDogVGVuc29yKTogdW5rbm93bltdIHtcbiAgcmV0dXJuIEFycmF5LmZyb20oaW5wdXQuZGF0YVN5bmMoKSkgYXMgdW5rbm93biBhcyB1bmtub3duW107XG59XG5cbmV4cG9ydCBmdW5jdGlvbiB0ZW5zb3JBcnJUbzJEQXJyKGlucHV0czogVGVuc29yW10pOiB1bmtub3duW11bXSB7XG4gIHJldHVybiBpbnB1dHMubWFwKGlucHV0ID0+IHRlbnNvclRvQXJyKGlucHV0KSk7XG59XG5cbi8qKlxuICogUmV0dXJucyBhIG5ldyBUZW5zb3Igd2l0aCBgdXBkYXRlc2AgaW5zZXJ0ZWQgaW50byBgaW5wdXRzYCBzdGFydGluZyBhdCB0aGVcbiAqIGluZGV4IGBzdGFydEluZGljZXNgLlxuICpcbiAqIEBwYXJhbSBpbnB1dHMgVGVuc29yIHRvIFwibW9kaWZ5XCJcbiAqIEBwYXJhbSBzdGFydEluZGljZXMgdGhlIHN0YXJ0aW5nIGluZGV4IHRvIGluc2VydCB0aGUgc2xpY2UuXG4gKiAgTGVuZ3RoIG11c3QgYmUgZXF1YWwgdG8gYGlucHV0cy5yYW5rYDtcbiAqIEBwYXJhbSB1cGRhdGVzIHRoZSB1cGRhdGUgdGVuc29yLiBTaGFwZSBtdXN0IGZpdCB3aXRoaW4gYGlucHV0c2Agc2hhcGUuXG4gKiBAcmV0dXJucyBhIG5ldyB0ZW5zb3Igd2l0aCB0aGUgbW9kaWZpY2F0aW9uLlxuICovXG5leHBvcnQgZnVuY3Rpb24gc2xpY2VVcGRhdGUoXG4gICAgaW5wdXRzOiBUZW5zb3IsIHN0YXJ0SW5kaWNlczogbnVtYmVyW10sIHVwZGF0ZXM6IFRlbnNvcik6IFRlbnNvciB7XG4gIHJldHVybiB0aWR5KCgpID0+IHtcbiAgICBjb25zdCBpbmRpY2VzOiBudW1iZXJbXVtdID0gW107XG4gICAgLyoqXG4gICAgICogQ29tcHV0ZXMgdGhlIHVwZGF0ZSBpbmRpY2VzIGJ5IGl0ZXJhdGluZyB0aHJvdWdoIGFsbCBpbmRpY2VzIGZyb21cbiAgICAgKiBgc3RhcnRJbmRpY2VzYCB0byBgc3RhcnRJbmRpY2VzICsgdXBkYXRlcy5zaGFwZWAuXG4gICAgICovXG4gICAgZnVuY3Rpb24gY3JlYXRlSW5kaWNlcyhpZHg6IG51bWJlciwgY3VycjogbnVtYmVyW10pOiB2b2lkIHtcbiAgICAgIGlmIChjdXJyLmxlbmd0aCA9PT0gc3RhcnRJbmRpY2VzLmxlbmd0aCkge1xuICAgICAgICBpbmRpY2VzLnB1c2goY3Vyci5zbGljZSgpKTtcbiAgICAgICAgcmV0dXJuO1xuICAgICAgfVxuICAgICAgY29uc3Qgc3RhcnQgPSBzdGFydEluZGljZXNbaWR4XTtcbiAgICAgIGNvbnN0IGVuZCA9IHN0YXJ0ICsgdXBkYXRlcy5zaGFwZVtpZHhdO1xuICAgICAgZm9yIChsZXQgaSA9IHN0YXJ0OyBpIDwgZW5kOyBpKyspIHtcbiAgICAgICAgY3Vyci5wdXNoKGkpO1xuICAgICAgICBjcmVhdGVJbmRpY2VzKGlkeCArIDEsIGN1cnIpO1xuICAgICAgICBjdXJyLnBvcCgpO1xuICAgICAgfVxuICAgIH1cbiAgICBjcmVhdGVJbmRpY2VzKDAsIFtdKTtcbiAgICAvLyBGbGF0dGVuIHRoZSB1cGRhdGVzIHRvIG1hdGNoIGxlbmd0aCBvZiBpdHMgdXBkYXRlIGluZGljZXMuXG4gICAgdXBkYXRlcyA9IHVwZGF0ZXMucmVzaGFwZShbdXBkYXRlcy5zaXplXSk7XG4gICAgcmV0dXJuIHRlbnNvclNjYXR0ZXJVcGRhdGUoaW5wdXRzLCBpbmRpY2VzLCB1cGRhdGVzKTtcbiAgfSk7XG59XG4iXX0=
function packXYSampleWeight(x, y, sampleWeight) {
throw new NotImplementedError();
}
function unPackXYSampleWeight(data) {
throw new NotImplementedError();
}
// TODO(pforderique): Figure out a workaround for `tf.data.Dataset`.
function convertInputsToDataset(x, y, sampleWeight, batchSize) {
throw new NotImplementedError();
}
function trainValidationSplit(arrays, validationSplit) {
throw new NotImplementedError();
}
class PipelineModel extends LayersModel {
constructor(args) {
var _a;
super(args);
this.includePreprocessing = (_a = args.includePreprocessing) !== null && _a !== void 0 ? _a : true;
}
/**
* An overridable function which preprocesses features.
*/
preprocessFeatures(x) {
return x;
}
/**
* An overridable function which preprocesses labels.
*/
preprocessLabels(y) {
return y;
}
/**
* An overridable function which preprocesses entire samples.
*/
preprocessSamples(x, y, sampleWeight) {
throw new NotImplementedError();
}
// ---------------------------------------------------------------------------
// Below are overrides to LayersModel methods to apply the functions above.
// ---------------------------------------------------------------------------
fit(x, y, args = {}) {
throw new NotImplementedError(`Uses ${convertInputsToDataset}, ${trainValidationSplit} ` +
`${packXYSampleWeight}, and ${unPackXYSampleWeight}`);
}
evaluate(x, y, args) {
throw new NotImplementedError();
}
predict(x, args) {
throw new NotImplementedError();
}
trainOnBatch(x, y, sampleWeight) {
throw new NotImplementedError();
}
predictOnBatch(x) {
throw new NotImplementedError();
}
}
/** @nocollapse */
PipelineModel.className = 'PipelineModel';
export { PipelineModel };
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"utils.js","sourceRoot":"","sources":["../../../../../../../tfjs-layers/src/layers/nlp/utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAsC,mBAAmB,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAItG,OAAO,EAAE,WAAW,EAAqB,MAAM,uBAAuB,CAAC;AAEvE,OAAO,EAAE,mBAAmB,EAAE,MAAM,cAAc,CAAC;AAEnD,MAAM,UAAU,WAAW,CAAC,KAAa;IACvC,OAAO,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAyB,CAAC;AAC9D,CAAC;AAED,MAAM,UAAU,gBAAgB,CAAC,MAAgB;IAC/C,OAAO,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC,CAAC;AACjD,CAAC;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,WAAW,CACvB,MAAc,EAAE,YAAsB,EAAE,OAAe;IACzD,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,MAAM,OAAO,GAAe,EAAE,CAAC;QAC/B;;;WAGG;QACH,SAAS,aAAa,CAAC,GAAW,EAAE,IAAc;YAChD,IAAI,IAAI,CAAC,MAAM,KAAK,YAAY,CAAC,MAAM,EAAE;gBACvC,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,EAAE,CAAC,CAAC;gBAC3B,OAAO;aACR;YACD,MAAM,KAAK,GAAG,YAAY,CAAC,GAAG,CAAC,CAAC;YAChC,MAAM,GAAG,GAAG,KAAK,GAAG,OAAO,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;YACvC,KAAK,IAAI,CAAC,GAAG,KAAK,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE;gBAChC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;gBACb,aAAa,CAAC,GAAG,GAAG,CAAC,EAAE,IAAI,CAAC,CAAC;gBAC7B,IAAI,CAAC,GAAG,EAAE,CAAC;aACZ;QACH,CAAC;QACD,aAAa,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QACrB,6DAA6D;QAC7D,OAAO,GAAG,OAAO,CAAC,OAAO,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC,CAAC;QAC1C,OAAO,mBAAmB,CAAC,MAAM,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC;IACvD,CAAC,CAAC,CAAC;AACL,CAAC;AAED,SAAS,kBAAkB,CAAC,CAAS,EAAE,CAAU,EAAE,YAAqB;IAItE,MAAM,IAAI,mBAAmB,EAAE,CAAC;AAClC,CAAC;AAED,SAAS,oBAAoB,CAC3B,IAAwD;IAExD,MAAM,IAAI,mBAAmB,EAAE,CAAC;AAClC,CAAC;AAED,oEAAoE;AACpE,SAAS,sBAAsB,CAC7B,CAAU,EAAE,CAAU,EAAE,YAAqB,EAAE,SAAkB;IAEjE,MAAM,IAAI,mBAAmB,EAAE,CAAC;AAClC,CAAC;AAED,SAAS,oBAAoB,CAAC,MAAgB,EAAE,eAAuB;IACrE,MAAM,IAAI,mBAAmB,EAAE,CAAC;AAClC,CAAC;AAYD,MAAa,aAAc,SAAQ,WAAW;IAM5C,YAAY,IAAuB;;QACjC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,oBAAoB,GAAG,MAAA,IAAI,CAAC,oBAAoB,mCAAI,IAAI,CAAC;IAChE,CAAC;IAED;;OAEG;IACH,kBAAkB,CAAC,CAAS;QAC1B,OAAO,CAAC,CAAC;IACX,CAAC;IAED;;OAEG;IACH,gBAAgB,CAAC,CAAS;QACxB,OAAO,CAAC,CAAC;IACX,CAAC;IAED;;OAEG;IACH,iBAAiB,CAAC,CAAS,EAAE,CAAU,EAAE,YAAqB;QAI5D,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAED,8EAA8E;IAC9E,2EAA2E;IAC3E,8EAA8E;IACrE,GAAG,CACV,CAAgD,EAChD,CAAgD,EAChD,OAAqB,EAAE;QAEvB,MAAM,IAAI,mBAAmB,CAC3B,QAAQ,sBAAsB,KAAK,oBAAoB,GAAG;YAC1D,GAAG,kBAAkB,SAAS,oBAAoB,EAAE,CAAC,CAAC;IAC1D,CAAC;IAEQ,QAAQ,CACf,CAAkB,EAClB,CAAkB,EAClB,IAAwB;QAExB,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAEQ,OAAO,CACd,CAAoB,EACpB,IAAyB;QAEzB,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAEQ,YAAY,CACnB,CAAgD,EAChD,CAAgD,EAChD,YAAqB;QAErB,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;IAEQ,cAAc,CAAC,CAAkB;QACxC,MAAM,IAAI,mBAAmB,EAAE,CAAC;IAClC,CAAC;;AAxED,kBAAkB;AACF,uBAAS,GAAG,eAAe,CAAC;SAFjC,aAAa","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport { ModelPredictConfig, Scalar, Tensor, tensorScatterUpdate, tidy } from '@tensorflow/tfjs-core';\n\nimport { History } from '../../base_callbacks';\nimport { ContainerArgs } from '../../engine/container';\nimport { LayersModel, ModelEvaluateArgs } from '../../engine/training';\nimport { ModelFitArgs } from '../../engine/training_tensors';\nimport { NotImplementedError } from '../../errors';\n\nexport function tensorToArr(input: Tensor): unknown[] {\n  return Array.from(input.dataSync()) as unknown as unknown[];\n}\n\nexport function tensorArrTo2DArr(inputs: Tensor[]): unknown[][] {\n  return inputs.map(input => tensorToArr(input));\n}\n\n/**\n * Returns a new Tensor with `updates` inserted into `inputs` starting at the\n * index `startIndices`.\n *\n * @param inputs Tensor to \"modify\"\n * @param startIndices the starting index to insert the slice.\n *  Length must be equal to `inputs.rank`;\n * @param updates the update tensor. Shape must fit within `inputs` shape.\n * @returns a new tensor with the modification.\n */\nexport function sliceUpdate(\n    inputs: Tensor, startIndices: number[], updates: Tensor): Tensor {\n  return tidy(() => {\n    const indices: number[][] = [];\n    /**\n     * Computes the update indices by iterating through all indices from\n     * `startIndices` to `startIndices + updates.shape`.\n     */\n    function createIndices(idx: number, curr: number[]): void {\n      if (curr.length === startIndices.length) {\n        indices.push(curr.slice());\n        return;\n      }\n      const start = startIndices[idx];\n      const end = start + updates.shape[idx];\n      for (let i = start; i < end; i++) {\n        curr.push(i);\n        createIndices(idx + 1, curr);\n        curr.pop();\n      }\n    }\n    createIndices(0, []);\n    // Flatten the updates to match length of its update indices.\n    updates = updates.reshape([updates.size]);\n    return tensorScatterUpdate(inputs, indices, updates);\n  });\n}\n\nfunction packXYSampleWeight(x: Tensor, y?: Tensor, sampleWeight?: Tensor):\n  Tensor\n  | [Tensor, Tensor]\n  | [Tensor, Tensor, Tensor] {\n  throw new NotImplementedError();\n}\n\nfunction unPackXYSampleWeight(\n  data: [Tensor]|[Tensor, Tensor]|[Tensor, Tensor, Tensor]\n) {\n  throw new NotImplementedError();\n}\n\n// TODO(pforderique): Figure out a workaround for `tf.data.Dataset`.\nfunction convertInputsToDataset(\n  x?: Tensor, y?: Tensor, sampleWeight?: Tensor, batchSize?: number\n) {\n  throw new NotImplementedError();\n}\n\nfunction trainValidationSplit(arrays: Tensor[], validationSplit: number) {\n  throw new NotImplementedError();\n}\n\n/**\n * A model which allows automatically applying preprocessing.\n */\nexport interface PipelineModelArgs extends ContainerArgs {\n  /**\n   * Defaults to true.\n   */\n  includePreprocessing?: boolean;\n}\n\nexport class PipelineModel extends LayersModel {\n  /** @nocollapse */\n  static override className = 'PipelineModel';\n\n  protected includePreprocessing: boolean;\n\n  constructor(args: PipelineModelArgs) {\n    super(args);\n    this.includePreprocessing = args.includePreprocessing ?? true;\n  }\n\n  /**\n   * An overridable function which preprocesses features.\n   */\n  preprocessFeatures(x: Tensor) {\n    return x;\n  }\n\n  /**\n   * An overridable function which preprocesses labels.\n   */\n  preprocessLabels(y: Tensor) {\n    return y;\n  }\n\n  /**\n   * An overridable function which preprocesses entire samples.\n   */\n  preprocessSamples(x: Tensor, y?: Tensor, sampleWeight?: Tensor):\n    Tensor\n    | [Tensor, Tensor]\n    | [Tensor, Tensor, Tensor] {\n    throw new NotImplementedError();\n  }\n\n  // ---------------------------------------------------------------------------\n  // Below are overrides to LayersModel methods to apply the functions above.\n  // ---------------------------------------------------------------------------\n  override fit(\n    x: Tensor|Tensor[]|{[inputName: string]: Tensor},\n    y: Tensor|Tensor[]|{[inputName: string]: Tensor},\n    args: ModelFitArgs = {}\n  ): Promise<History> {\n    throw new NotImplementedError(\n      `Uses ${convertInputsToDataset}, ${trainValidationSplit} ` +\n      `${packXYSampleWeight}, and ${unPackXYSampleWeight}`);\n  }\n\n  override evaluate(\n    x: Tensor|Tensor[],\n    y: Tensor|Tensor[],\n    args?: ModelEvaluateArgs\n  ): Scalar | Scalar[] {\n    throw new NotImplementedError();\n  }\n\n  override predict(\n    x: Tensor | Tensor[],\n    args?: ModelPredictConfig\n  ): Tensor | Tensor[] {\n    throw new NotImplementedError();\n  }\n\n  override trainOnBatch(\n    x: Tensor|Tensor[]|{[inputName: string]: Tensor},\n    y: Tensor|Tensor[]|{[inputName: string]: Tensor},\n    sampleWeight?: Tensor\n  ): Promise<number|number[]> {\n    throw new NotImplementedError();\n  }\n\n  override predictOnBatch(x: Tensor|Tensor[]): Tensor|Tensor[] {\n    throw new NotImplementedError();\n  }\n}\n"]}

@@ -36,3 +36,3 @@ /**

*/
export declare function toList(x: any): any[];
export declare function toList<T>(x: T | T[]): T[];
/**

@@ -39,0 +39,0 @@ * Generate a UID for a list

@@ -517,2 +517,2 @@ /**

}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"generic_utils.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/utils/generic_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,6CAA6C;AAE7C,OAAO,EAAiC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE3E,OAAO,EAAC,cAAc,EAAE,UAAU,EAAC,MAAM,WAAW,CAAC;AAErD,gBAAgB;AAEhB;;;GAGG;AACH,kCAAkC;AAClC,MAAM,UAAU,YAAY,CAAC,KAAU,EAAE,SAAiB;IACxD,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QACxB,kCAAkC;QAClC,IAAI,QAAQ,GAAU,EAAE,CAAC;QACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE;YAClC,QAAQ,GAAG,QAAQ,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;SACnC;QACD,OAAO,QAAQ,CAAC;KACjB;SAAM;QACL,MAAM,QAAQ,GAAG,IAAI,KAAK,CAAC,SAAS,CAAC,CAAC;QACtC,QAAQ,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACrB,OAAO,QAAQ,CAAC;KACjB;AACH,CAAC;AAED,MAAM,UAAU,MAAM,CAAC,GAAY,EAAE,OAAgB;IACnD,IAAI,CAAC,GAAG,EAAE;QACR,MAAM,IAAI,cAAc,CAAC,OAAO,CAAC,CAAC;KACnC;AACH,CAAC;AAED;;GAEG;AACH,MAAM,UAAU,KAAK,CAAI,KAAU,EAAE,QAAW;IAC9C,IAAI,OAAO,GAAG,CAAC,CAAC;IAChB,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;QACxB,IAAI,IAAI,KAAK,QAAQ,EAAE;YACrB,OAAO,EAAE,CAAC;SACX;KACF;IACD,OAAO,OAAO,CAAC;AACjB,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,gBAAgB,CAAI,EAAO;IACzC,IAAI,EAAE,CAAC,MAAM,KAAK,CAAC,EAAE;QACnB,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC;KACd;IACD,OAAO,EAAE,CAAC;AACZ,CAAC;AAED;;;;;;;GAOG;AACH,kCAAkC;AAClC,MAAM,UAAU,MAAM,CAAC,CAAM;IAC3B,IAAI,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;QACpB,OAAO,CAAC,CAAC;KACV;IACD,OAAO,CAAC,CAAC,CAAC,CAAC;AACb,CAAC;AAED;;GAEG;AACH,kCAAkC;AAClC,MAAM,UAAU,aAAa,CAAC,IAAe;IAC3C,MAAM,UAAU,GAAG,MAAM,CAAC,IAAI,CAAC,CAAC;IAChC,IAAI,MAAM,GAAG,EAAE,CAAC;IAChB,KAAK,MAAM,GAAG,IAAI,UAAU,EAAE;QAC5B,IAAI,GAAG,CAAC,EAAE,IAAI,IAAI,EAAE;YAClB,MAAM,IAAI,UAAU,CAChB,UAAU,GAAG,wCAAwC,CAAC,CAAC;SAC5D;QACD,IAAI,MAAM,KAAK,EAAE,EAAE;YACjB,MAAM,GAAG,MAAM,GAAG,IAAI,CAAC;SACxB;QACD,MAAM,GAAG,GAAG,MAAM,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC;KACzC;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AACD;;;GAGG;AACH,MAAM,UAAU,WAAW,CAAC,IAAY;IACtC,MAAM,YAAY,GAAG,IAAI,CAAC,OAAO,CAAC,sBAAsB,EAAE,OAAO,CAAC,CAAC;IACnE,MAAM,QAAQ,GACV,YAAY,CAAC,OAAO,CAAC,iBAAiB,EAAE,OAAO,CAAC,CAAC,WAAW,EAAE,CAAC;IACnE;;;OAGG;IACH,IAAI,QAAQ,CAAC,CAAC,CAAC,KAAK,GAAG,EAAE;QACvB,OAAO,QAAQ,CAAC;KACjB;IACD,OAAO,SAAS,GAAG,QAAQ,CAAC;AAC9B,CAAC;AAED,MAAM,UAAU,WAAW,CAAC,UAAkB;IAC5C,4DAA4D;IAC5D,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE;QAC1B,OAAO,UAAU,CAAC;KACnB;IACD,iDAAiD;IACjD,IAAI,UAAU,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,EAAE;QAClC,OAAO,UAAU,CAAC;KACnB;IACD,OAAO,UAAU,CAAC,OAAO,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,WAAW,EAAE,CAAC,CAAC;AACxE,CAAC;AAED,kCAAkC;AAClC,IAAI,sBAAsB,GAAG,EAA8B,CAAC;AAE5D,MAAM,UAAU,oBAAoB,CAAC,QAAoC;IAEvE,IAAI,QAAQ,KAAK,IAAI,IAAI,QAAQ,KAAK,SAAS,EAAE;QAC/C,OAAO,IAAI,CAAC;KACb;IACD,MAAM,IAAI,GAAkC,EAAE,CAAC;IAC/C,IAAI,CAAC,WAAW,CAAC,GAAG,QAAQ,CAAC,YAAY,EAAE,CAAC;IAC5C,IAAI,CAAC,QAAQ,CAAC,GAAG,QAAQ,CAAC,SAAS,EAAE,CAAC;IACtC,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;;;;;;;GAUG;AACH,SAAS,6BAA6B,CAAC,MAAqC;IAE1E,IAAI,MAAM,IAAI,IAAI,IAAI,OAAO,MAAM,KAAK,QAAQ,EAAE;QAChD,OAAO;KACR;SAAM,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;QAChC,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC,6BAA6B,CAAC,UAAU,CAAC,CAAC,CAAC;KACzE;SAAM;QACL,MAAM,MAAM,GAAG,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;QACnC,KAAK,MAAM,KAAK,IAAI,MAAM,EAAE;YAC1B,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC;YAC5B,IAAI,KAAK,IAAI,IAAI,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;gBAC9C,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,KAAK,CAAC,MAAM,CAAC,KAAK,SAAS;oBACpD,OAAO,KAAK,CAAC,OAAO,CAAC,KAAK,QAAQ,EAAE;oBACtC,MAAM,CAAC,KAAK,CAAC,GAAG,KAAK,CAAC,OAAO,CAAC,CAAC;iBAChC;qBAAM;oBACL,6BAA6B,CAAC,KAAiC,CAAC,CAAC;iBAClE;aACF;SACF;KACF;AACH,CAAC;AAED;;;;;;;;;;;GAWG;AACH,wBAAwB;AACxB,MAAM,UAAU,sBAAsB,CAClC,UAA2C,EAC3C,gBAAgB,EAA8B,EAC9C,gBAAgB,EAA8B,EAC9C,mBAAmB,GAAG,QAAQ,EAAE,cAAc,GAAG,KAAK;IACxD,gBAAgB;IAChB,IAAI,OAAO,UAAU,KAAK,QAAQ,EAAE;QAClC,MAAM,YAAY,GAAG,UAAU,CAAC;QAChC,IAAI,EAAE,CAAC;QACP,IAAI,YAAY,IAAI,aAAa,EAAE;YACjC,EAAE,GAAG,aAAa,CAAC,YAAY,CAAC,CAAC;SAClC;aAAM,IAAI,YAAY,IAAI,sBAAsB,EAAE;YACjD,EAAE,GAAG,sBAAsB,CAAC,YAAY,CAAC,CAAC;SAC3C;aAAM;YACL,EAAE,GAAG,aAAa,CAAC,YAAY,CAAC,CAAC;YACjC,IAAI,EAAE,IAAI,IAAI,EAAE;gBACd,MAAM,IAAI,UAAU,CAChB,WAAW,mBAAmB,KAAK,UAAU,IAAI;oBACjD,oDAAoD;oBACpD,UAAU,mBAAmB,kCAAkC;oBAC/D,iEAAiE;oBACjE,SAAS;oBACT,iBAAiB,mBAAmB,6BAA6B;oBACjE,sCAAsC;oBACtC,mCAAmC,CAAC,CAAC;gBACzC,0DAA0D;aAC3D;SACF;QACD,OAAO,EAAE,CAAC;KACX;SAAM;QACL,8DAA8D;QAC9D,MAAM,MAAM,GAAG,UAAU,CAAC;QAC1B,IAAI,MAAM,CAAC,WAAW,CAAC,IAAI,IAAI,IAAI,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAI,EAAE;YAC3D,MAAM,IAAI,UAAU,CAChB,GAAG,mBAAmB,4BAA4B;gBAClD,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK;gBAC9B,oCAAoC,CAAC,CAAC;SAC3C;QACD,MAAM,SAAS,GAAG,MAAM,CAAC,WAAW,CAAW,CAAC;QAChD,IAAI,GAAG,EAAE,UAAU,CAAC;QACpB,IAAI,SAAS,IAAI,aAAa,EAAE;YAC9B,CAAC,GAAG,EAAE,UAAU,CAAC,GAAG,aAAa,CAAC,SAAS,CAAC,CAAC;SAC9C;aAAM,IAAI,SAAS,IAAI,sBAAsB,EAAE;YAC9C,CAAC,GAAG,EAAE,UAAU,CAAC,GAAG,sBAAsB,CAAC,WAAW,CAAC,CAAC;SACzD;aAAM,IAAI,SAAS,IAAI,aAAa,EAAE;YACrC,CAAC,GAAG,EAAE,UAAU,CAAC,GAAG,aAAa,CAAC,SAAS,CAAC,CAAC;SAC9C;QACD,IAAI,GAAG,IAAI,IAAI,EAAE;YACf,MAAM,IAAI,UAAU,CAChB,WAAW,mBAAmB,KAAK,SAAS,IAAI;gBAChD,oDAAoD;gBACpD,UAAU,mBAAmB,kCAAkC;gBAC/D,iEAAiE;gBACjE,SAAS;gBACT,iBAAiB,mBAAmB,6BAA6B;gBACjE,sCAAsC;gBACtC,mCAAmC,CAAC,CAAC;YACzC,0DAA0D;SAC3D;QACD,IAAI,UAAU,IAAI,IAAI,EAAE;YACtB,uEAAuE;YACvE,wEAAwE;YACxE,0EAA0E;YAC1E,gBAAgB;YAEhB,kCAAkC;YAClC,MAAM,qBAAqB,GAAG,EAA8B,CAAC;YAC7D,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,sBAAsB,CAAC,EAAE;gBACrD,qBAAqB,CAAC,GAAG,CAAC,GAAG,sBAAsB,CAAC,GAAG,CAAC,CAAC;aAC1D;YACD,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBAC5C,qBAAqB,CAAC,GAAG,CAAC,GAAG,aAAa,CAAC,GAAG,CAAC,CAAC;aACjD;YACD,kCAAkC;YAClC,MAAM,YAAY,GAAG,MAAM,CAAC,QAAQ,CAA6B,CAAC;YAClE,YAAY,CAAC,eAAe,CAAC,GAAG,qBAAqB,CAAC;YAEtD,MAAM,mBAAmB,qBAAO,sBAAsB,CAAC,CAAC;YACxD,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBAC5C,sBAAsB,CAAC,GAAG,CAAC,GAAG,aAAa,CAAC,GAAG,CAAC,CAAC;aAClD;YACD,6BAA6B,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC;YAChD,MAAM,SAAS,GACX,UAAU,CAAC,GAAG,EAAE,MAAM,CAAC,QAAQ,CAAC,EAAE,aAAa,EAAE,cAAc,CAAC,CAAC;YACrE,sBAAsB,qBAAO,mBAAmB,CAAC,CAAC;YAElD,OAAO,SAAS,CAAC;SAClB;aAAM;YACL,kDAAkD;YAClD,4CAA4C;YAC5C,8BAA8B;YAC9B,MAAM,mBAAmB,qBAAO,sBAAsB,CAAC,CAAC;YACxD,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBAC5C,sBAAsB,CAAC,GAAG,CAAC,GAAG,aAAa,CAAC,GAAG,CAAC,CAAC;aAClD;YACD,mEAAmE;YACnE,iEAAiE;YACjE,oEAAoE;YACpE,MAAM,SAAS,GAAG,IAAI,GAAG,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC;YAC5C,sBAAsB,qBAAO,mBAAmB,CAAC,CAAC;YAClD,OAAO,SAAS,CAAC;SAClB;KACF;AACH,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,aAAa,CAAC,CAAS,EAAE,CAAS;IAChD,OAAO,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;AAC1C,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,oBAAoB,CAAC,CAAS,EAAE,CAAS;IACvD,OAAO,CAAC,CAAC,GAAG,aAAa,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;AAClC,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,aAAa,CAAC,KAAa;IACzC,QAAQ,KAAK,EAAE;QACb,KAAK,SAAS;YACZ,OAAO,SAAS,CAAC;QACnB;YACE,MAAM,IAAI,UAAU,CAAC,kBAAkB,KAAK,EAAE,CAAC,CAAC;KACnD;AACH,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,YAAY,CAAC,EAAY,EAAE,EAAY;IACrD,IAAI,EAAE,IAAI,IAAI,IAAI,EAAE,IAAI,IAAI,EAAE;QAC5B,OAAO,EAAE,KAAK,EAAE,CAAC;KAClB;IACD,IAAI,EAAE,CAAC,MAAM,KAAK,EAAE,CAAC,MAAM,EAAE;QAC3B,OAAO,KAAK,CAAC;KACd;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QAClC,IAAI,EAAE,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,EAAE;YACnB,OAAO,KAAK,CAAC;SACd;KACF;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,MAAM,CAAI,EAAO;IAC/B,IAAI,EAAE,IAAI,IAAI,EAAE;QACd,OAAO,EAAE,CAAC;KACX;IACD,MAAM,GAAG,GAAQ,EAAE,CAAC;IACpB,oDAAoD;IACpD,KAAK,MAAM,CAAC,IAAI,EAAE,EAAE;QAClB,IAAI,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE;YACzB,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACb;KACF;IACD,OAAO,GAAG,CAAC;AACb,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,aAAa,CAAC,GAAO;IACnC,IAAI,GAAG,IAAI,IAAI,EAAE;QACf,MAAM,IAAI,UAAU,CAAC,yBAAyB,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;KACtE;IACD,KAAK,MAAM,GAAG,IAAI,GAAG,EAAE;QACrB,IAAI,GAAG,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE;YAC3B,OAAO,KAAK,CAAC;SACd;KACF;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,yBAAyB,CACrC,MAAgB,EAAE,KAAa,EAAE,KAAa;IAChD,IAAI,KAAK,IAAI,IAAI,EAAE;QACjB,OAAO;KACR;IACD,IAAI,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,GAAG,CAAC,EAAE;QAC7B,MAAM,IAAI,UAAU,CAAC,GAAG,KAAK,mBAAmB,KAAK,uBACjD,MAAM,qBAAqB,CAAC,CAAC;KAClC;AACH,CAAC;AAED;;;;;;;;;;;;;GAaG;AACH,wBAAwB;AACxB,MAAM,UAAU,uBAAuB,CACnC,CAAM,EAAE,YAAoB,EAAE,SAAS,GAAG,CAAC,EAC3C,SAAS,GAAG,QAAQ;IACtB,MAAM,CAAC,SAAS,IAAI,CAAC,CAAC,CAAC;IACvB,MAAM,CAAC,SAAS,IAAI,SAAS,CAAC,CAAC;IAC/B,OAAO,CACH,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,MAAM,IAAI,SAAS,IAAI,CAAC,CAAC,MAAM,IAAI,SAAS;QAClE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,KAAK,YAAY,CAAC,CAAC,CAAC;AAC/C,CAAC;AACD,uBAAuB;AAEvB;;;;;;GAMG;AACH,MAAM,UAAU,qBAAqB,CAAC,KAAsB,EAAE,IAAY;IACxE,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QACxB,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,GAAG,EAAE,CAAC,GAAG,IAAI,kCAAkC,CAAC,CAAC;QACvE,KAAK,CAAC,OAAO,CACT,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,qBAAqB,CAAC,CAAC,EAAE,WAAW,CAAC,GAAG,CAAC,OAAO,IAAI,EAAE,CAAC,CAAC,CAAC;KACxE;SAAM;QACL,IAAI,CAAC,MAAM,CACP,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,KAAK,GAAG,CAAC,EACpC,GAAG,EAAE,CAAC,YAAY,IAAI,qCAAqC;YACvD,GAAG,sBAAsB,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;KAC9C;AACH,CAAC;AAED;;;;;;;;;GASG;AACH,kCAAkC;AAClC,MAAM,UAAU,sBAAsB,CAAC,KAAU;IAC/C,IAAI,KAAK,KAAK,IAAI,EAAE;QAClB,OAAO,MAAM,CAAC;KACf;SAAM,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QAC/B,OAAO,GAAG,GAAG,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,sBAAsB,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,GAAG,GAAG,CAAC;KACxE;SAAM,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;QACpC,OAAO,IAAI,KAAK,GAAG,CAAC;KACrB;SAAM;QACL,OAAO,GAAG,KAAK,EAAE,CAAC;KACnB;AACH,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,QAAQ,CACpB,CAA4B,EAAE,MAAc,EAC5C,OAAkB;IACpB,IAAI,QAAQ,GAAG,OAAO,IAAI,IAAI,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC;IACxD,IAAI,UAAa,CAAC;IAClB,MAAM,EAAE,GAAG,CAAC,GAAG,IAAe,EAAE,EAAE;QAChC,MAAM,GAAG,GAAG,OAAO,IAAI,IAAI,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC;QACrD,IAAI,GAAG,GAAG,QAAQ,GAAG,MAAM,EAAE;YAC3B,OAAO,UAAU,CAAC;SACnB;QACD,QAAQ,GAAG,GAAG,CAAC;QACf,UAAU,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;QACxB,OAAO,UAAU,CAAC;IACpB,CAAC,CAAC;IACF,OAAO,EAAE,CAAC;AACZ,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,0BAA0B,CAAC,cAAsB;IAE/D,IAAI,cAAc,KAAK,MAAM,EAAE;QAC7B,OAAO,MAAM,CAAC;KACf;IACD,IAAI,cAAc,KAAK,QAAQ,EAAE;QAC/B,OAAO,QAAQ,CAAC;KACjB;IACD,IAAI,cAAc,KAAK,KAAK,EAAE;QAC5B,OAAO,KAAK,CAAC;KACd;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAID;;;;;;;;;;;;;;GAcG;AACH,MAAM,UAAU,2BAA2B,CAAC,GAAG,aAA6B;IAE1E,MAAM,CAAC,aAAa,CAAC,MAAM,GAAG,CAAC,EAAE,wBAAwB,CAAC,CAAC;IAE3D,KAAK,MAAM,MAAM,IAAI,aAAa,EAAE;QAClC,MAAM,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,mCAAmC,CAAC,CAAC;QACnE,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,EAAE,4BAA4B,CAAC,CAAC;KACzD;IAED,OAAO,aAAa,CAAC,MAAM,CAAC,CAAC,QAAQ,EAAE,MAAM,EAAE,EAAE;QAC/C,IAAI,QAAQ,CAAC,MAAM,KAAK,CAAC,EAAE;YACzB,OAAO,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;SACrC;QAED,OAAO,MAAM;aACR,GAAG,CAAC,KAAK,CAAC,EAAE;YACX,OAAO,QAAQ,CAAC,GAAG,CAAC,CAAC,SAAS,EAAE,EAAE,CAAC,CAAC,GAAG,SAAS,EAAE,KAAK,CAAC,CAAC,CAAC;QAC5D,CAAC,CAAC;aACD,MAAM,CAAC,CAAC,gBAAgB,EAAE,kBAAkB,EAAE,EAAE;YAC/C,OAAO,gBAAgB,CAAC,MAAM,CAAC,kBAAkB,CAAC,CAAC;QACrD,CAAC,EAAE,EAAE,CAAC,CAAC;IACb,CAAC,EAAE,EAAoB,CAAC,CAAC;AAC3B,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* Original source: utils/generic_utils.py */\n\nimport {DataType, fused, serialization, util} from '@tensorflow/tfjs-core';\n\nimport {AssertionError, ValueError} from '../errors';\n\n// tslint:enable\n\n/**\n * If `value` is an Array, equivalent to Python's `value * numValues`.\n * If `value` is not an Array, equivalent to Python's `[value] * numValues`\n */\n// tslint:disable-next-line:no-any\nexport function pyListRepeat(value: any, numValues: number): any[] {\n  if (Array.isArray(value)) {\n    // tslint:disable-next-line:no-any\n    let newArray: any[] = [];\n    for (let i = 0; i < numValues; i++) {\n      newArray = newArray.concat(value);\n    }\n    return newArray;\n  } else {\n    const newArray = new Array(numValues);\n    newArray.fill(value);\n    return newArray;\n  }\n}\n\nexport function assert(val: boolean, message?: string): void {\n  if (!val) {\n    throw new AssertionError(message);\n  }\n}\n\n/**\n * Count the number of elements of the `array` that are equal to `reference`.\n */\nexport function count<T>(array: T[], refernce: T) {\n  let counter = 0;\n  for (const item of array) {\n    if (item === refernce) {\n      counter++;\n    }\n  }\n  return counter;\n}\n\n/**\n * If an array is of length 1, just return the first element. Otherwise, return\n * the full array.\n * @param tensors\n */\nexport function singletonOrArray<T>(xs: T[]): T|T[] {\n  if (xs.length === 1) {\n    return xs[0];\n  }\n  return xs;\n}\n\n/**\n * Normalizes a list/tensor into a list.\n *\n * If a tensor is passed, we return\n * a list of size 1 containing the tensor.\n *\n * @param x target object to be normalized.\n */\n// tslint:disable-next-line:no-any\nexport function toList(x: any): any[] {\n  if (Array.isArray(x)) {\n    return x;\n  }\n  return [x];\n}\n\n/**\n * Generate a UID for a list\n */\n// tslint:disable-next-line:no-any\nexport function objectListUid(objs: any|any[]): string {\n  const objectList = toList(objs);\n  let retVal = '';\n  for (const obj of objectList) {\n    if (obj.id == null) {\n      throw new ValueError(\n          `Object ${obj} passed to objectListUid without an id`);\n    }\n    if (retVal !== '') {\n      retVal = retVal + ', ';\n    }\n    retVal = `${retVal}${Math.abs(obj.id)}`;\n  }\n  return retVal;\n}\n/**\n * Converts string to snake-case.\n * @param name\n */\nexport function toSnakeCase(name: string): string {\n  const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');\n  const insecure =\n      intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();\n  /*\n   If the class is private the name starts with \"_\" which is not secure\n   for creating scopes. We prefix the name with \"private\" in this case.\n   */\n  if (insecure[0] !== '_') {\n    return insecure;\n  }\n  return 'private' + insecure;\n}\n\nexport function toCamelCase(identifier: string): string {\n  // quick return for empty string or single character strings\n  if (identifier.length <= 1) {\n    return identifier;\n  }\n  // Check for the underscore indicating snake_case\n  if (identifier.indexOf('_') === -1) {\n    return identifier;\n  }\n  return identifier.replace(/[_]+(\\w|$)/g, (m, p1) => p1.toUpperCase());\n}\n\n// tslint:disable-next-line:no-any\nlet _GLOBAL_CUSTOM_OBJECTS = {} as {[objName: string]: any};\n\nexport function serializeKerasObject(instance: serialization.Serializable):\n    serialization.ConfigDictValue {\n  if (instance === null || instance === undefined) {\n    return null;\n  }\n  const dict: serialization.ConfigDictValue = {};\n  dict['className'] = instance.getClassName();\n  dict['config'] = instance.getConfig();\n  return dict;\n}\n\n/**\n * Replace ndarray-style scalar objects in serialization objects with numbers.\n *\n * Background: In some versions of tf.keras, certain scalar values in the HDF5\n * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,\n * where in `num` is a plain number. This method converts such serialization\n * to a `number`.\n *\n * @param config The keras-format serialization object to be processed\n *   (in place).\n */\nfunction convertNDArrayScalarsInConfig(config: serialization.ConfigDictValue):\n    void {\n  if (config == null || typeof config !== 'object') {\n    return;\n  } else if (Array.isArray(config)) {\n    config.forEach(configItem => convertNDArrayScalarsInConfig(configItem));\n  } else {\n    const fields = Object.keys(config);\n    for (const field of fields) {\n      const value = config[field];\n      if (value != null && typeof value === 'object') {\n        if (!Array.isArray(value) && value['type'] === 'ndarray' &&\n            typeof value['value'] === 'number') {\n          config[field] = value['value'];\n        } else {\n          convertNDArrayScalarsInConfig(value as serialization.ConfigDict);\n        }\n      }\n    }\n  }\n}\n\n/**\n * Deserialize a saved Keras Object\n * @param identifier either a string ID or a saved Keras dictionary\n * @param moduleObjects a list of Python class names to object constructors\n * @param customObjects a list of Python class names to object constructors\n * @param printableModuleName debug text for the object being reconstituted\n * @param fastWeightInit Optional flag to use fast weight initialization\n *   during deserialization. This is applicable to cases in which\n *   the initialization will be immediately overwritten by loaded weight\n *   values. Default: `false`.\n * @returns a TensorFlow.js Layers object\n */\n// tslint:disable:no-any\nexport function deserializeKerasObject(\n    identifier: string|serialization.ConfigDict,\n    moduleObjects = {} as {[objName: string]: any},\n    customObjects = {} as {[objName: string]: any},\n    printableModuleName = 'object', fastWeightInit = false): any {\n  // tslint:enable\n  if (typeof identifier === 'string') {\n    const functionName = identifier;\n    let fn;\n    if (functionName in customObjects) {\n      fn = customObjects[functionName];\n    } else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {\n      fn = _GLOBAL_CUSTOM_OBJECTS[functionName];\n    } else {\n      fn = moduleObjects[functionName];\n      if (fn == null) {\n        throw new ValueError(\n            `Unknown ${printableModuleName}: ${identifier}. ` +\n            `This may be due to one of the following reasons:\\n` +\n            `1. The ${printableModuleName} is defined in Python, in which ` +\n            `case it needs to be ported to TensorFlow.js or your JavaScript ` +\n            `code.\\n` +\n            `2. The custom ${printableModuleName} is defined in JavaScript, ` +\n            `but is not registered properly with ` +\n            `tf.serialization.registerClass().`);\n        // TODO(cais): Add link to tutorial page on custom layers.\n      }\n    }\n    return fn;\n  } else {\n    // In this case we are dealing with a Keras config dictionary.\n    const config = identifier;\n    if (config['className'] == null || config['config'] == null) {\n      throw new ValueError(\n          `${printableModuleName}: Improper config format: ` +\n          `${JSON.stringify(config)}.\\n` +\n          `'className' and 'config' must set.`);\n    }\n    const className = config['className'] as string;\n    let cls, fromConfig;\n    if (className in customObjects) {\n      [cls, fromConfig] = customObjects[className];\n    } else if (className in _GLOBAL_CUSTOM_OBJECTS) {\n      [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS['className'];\n    } else if (className in moduleObjects) {\n      [cls, fromConfig] = moduleObjects[className];\n    }\n    if (cls == null) {\n      throw new ValueError(\n          `Unknown ${printableModuleName}: ${className}. ` +\n          `This may be due to one of the following reasons:\\n` +\n          `1. The ${printableModuleName} is defined in Python, in which ` +\n          `case it needs to be ported to TensorFlow.js or your JavaScript ` +\n          `code.\\n` +\n          `2. The custom ${printableModuleName} is defined in JavaScript, ` +\n          `but is not registered properly with ` +\n          `tf.serialization.registerClass().`);\n      // TODO(cais): Add link to tutorial page on custom layers.\n    }\n    if (fromConfig != null) {\n      // Porting notes: Instead of checking to see whether fromConfig accepts\n      // customObjects, we create a customObjects dictionary and tack it on to\n      // config['config'] as config['config'].customObjects. Objects can use it,\n      // if they want.\n\n      // tslint:disable-next-line:no-any\n      const customObjectsCombined = {} as {[objName: string]: any};\n      for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {\n        customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];\n      }\n      for (const key of Object.keys(customObjects)) {\n        customObjectsCombined[key] = customObjects[key];\n      }\n      // Add the customObjects to config\n      const nestedConfig = config['config'] as serialization.ConfigDict;\n      nestedConfig['customObjects'] = customObjectsCombined;\n\n      const backupCustomObjects = {..._GLOBAL_CUSTOM_OBJECTS};\n      for (const key of Object.keys(customObjects)) {\n        _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];\n      }\n      convertNDArrayScalarsInConfig(config['config']);\n      const returnObj =\n          fromConfig(cls, config['config'], customObjects, fastWeightInit);\n      _GLOBAL_CUSTOM_OBJECTS = {...backupCustomObjects};\n\n      return returnObj;\n    } else {\n      // Then `cls` may be a function returning a class.\n      // In this case by convention `config` holds\n      // the kwargs of the function.\n      const backupCustomObjects = {..._GLOBAL_CUSTOM_OBJECTS};\n      for (const key of Object.keys(customObjects)) {\n        _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];\n      }\n      // In python this is **config['config'], for tfjs-layers we require\n      // classes that use this fall-through construction method to take\n      // a config interface that mimics the expansion of named parameters.\n      const returnObj = new cls(config['config']);\n      _GLOBAL_CUSTOM_OBJECTS = {...backupCustomObjects};\n      return returnObj;\n    }\n  }\n}\n\n/**\n * Compares two numbers for sorting.\n * @param a\n * @param b\n */\nexport function numberCompare(a: number, b: number) {\n  return (a < b) ? -1 : ((a > b) ? 1 : 0);\n}\n\n/**\n * Comparison of two numbers for reverse sorting.\n * @param a\n * @param b\n */\nexport function reverseNumberCompare(a: number, b: number) {\n  return -1 * numberCompare(a, b);\n}\n\n/**\n * Convert a string into the corresponding DType.\n * @param dtype\n * @returns An instance of DType.\n */\nexport function stringToDType(dtype: string): DataType {\n  switch (dtype) {\n    case 'float32':\n      return 'float32';\n    default:\n      throw new ValueError(`Invalid dtype: ${dtype}`);\n  }\n}\n\n/**\n * Test the element-by-element equality of two Arrays of strings.\n * @param xs First array of strings.\n * @param ys Second array of strings.\n * @returns Wether the two arrays are all equal, element by element.\n */\nexport function stringsEqual(xs: string[], ys: string[]): boolean {\n  if (xs == null || ys == null) {\n    return xs === ys;\n  }\n  if (xs.length !== ys.length) {\n    return false;\n  }\n  for (let i = 0; i < xs.length; ++i) {\n    if (xs[i] !== ys[i]) {\n      return false;\n    }\n  }\n  return true;\n}\n\n/**\n * Get the unique elements of an array.\n * @param xs Array.\n * @returns An Array consisting of the unique elements in `xs`.\n */\nexport function unique<T>(xs: T[]): T[] {\n  if (xs == null) {\n    return xs;\n  }\n  const out: T[] = [];\n  // TODO(cais): Maybe improve performance by sorting.\n  for (const x of xs) {\n    if (out.indexOf(x) === -1) {\n      out.push(x);\n    }\n  }\n  return out;\n}\n\n/**\n * Determine if an Object is empty (i.e., does not have own properties).\n * @param obj Object\n * @returns Whether the Object is empty.\n * @throws ValueError: If object is `null` or `undefined`.\n */\nexport function isObjectEmpty(obj: {}): boolean {\n  if (obj == null) {\n    throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);\n  }\n  for (const key in obj) {\n    if (obj.hasOwnProperty(key)) {\n      return false;\n    }\n  }\n  return true;\n}\n\n/**\n * Helper function used to build type union/enum run-time checkers.\n * @param values The list of allowed values.\n * @param label A string name for the type\n * @param value The value to test.\n * @throws ValueError: If the value is not in values nor `undefined`/`null`.\n */\nexport function checkStringTypeUnionValue(\n    values: string[], label: string, value: string): void {\n  if (value == null) {\n    return;\n  }\n  if (values.indexOf(value) < 0) {\n    throw new ValueError(`${value} is not a valid ${label}.  Valid values are ${\n        values} or null/undefined.`);\n  }\n}\n\n/**\n * Helper function for verifying the types of inputs.\n *\n * Ensures that the elements of `x` are all of type `expectedType`.\n * Also verifies that the length of `x` is within bounds.\n *\n * @param x Object to test.\n * @param expectedType The string expected type of all of the elements in the\n * Array.\n * @param minLength Return false if x.length is less than this.\n * @param maxLength Return false if x.length is greater than this.\n * @returns true if and only if `x` is an `Array<expectedType>` with\n * length >= `minLength` and <= `maxLength`.\n */\n// tslint:disable:no-any\nexport function checkArrayTypeAndLength(\n    x: any, expectedType: string, minLength = 0,\n    maxLength = Infinity): boolean {\n  assert(minLength >= 0);\n  assert(maxLength >= minLength);\n  return (\n      Array.isArray(x) && x.length >= minLength && x.length <= maxLength &&\n      x.every(e => typeof e === expectedType));\n}\n// tslint:enable:no-any\n\n/**\n * Assert that a value or an array of value are positive integer.\n *\n * @param value The value being asserted on. May be a single number or an array\n *   of numbers.\n * @param name Name of the value, used to make the error message.\n */\nexport function assertPositiveInteger(value: number|number[], name: string) {\n  if (Array.isArray(value)) {\n    util.assert(\n        value.length > 0, () => `${name} is unexpectedly an empty array.`);\n    value.forEach(\n        (v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));\n  } else {\n    util.assert(\n        Number.isInteger(value) && value > 0,\n        () => `Expected ${name} to be a positive integer, but got ` +\n            `${formatAsFriendlyString(value)}.`);\n  }\n}\n\n/**\n * Format a value into a display-friendly, human-readable fashion.\n *\n * - `null` is formatted as `'null'`\n * - Strings are formated with flanking pair of quotes.\n * - Arrays are formatted with flanking pair of square brackets.\n *\n * @param value The value to display.\n * @return Formatted string.\n */\n// tslint:disable-next-line:no-any\nexport function formatAsFriendlyString(value: any): string {\n  if (value === null) {\n    return 'null';\n  } else if (Array.isArray(value)) {\n    return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']';\n  } else if (typeof value === 'string') {\n    return `\"${value}\"`;\n  } else {\n    return `${value}`;\n  }\n}\n\n/**\n * Returns a function `f2` (decorator) which wraps the original function\n * `f`. `f2` guarantees that `f` can be called at most once\n * every `waitMs` ms. If `f2` is called more often, it will return\n * the last returned result of `f`.\n *\n * @param f The original function `f` to wrap.\n * @param waitMs The time between two consecutive calls to `f` in ms.\n */\nexport function debounce<T>(\n    f: (...args: Array<{}>) => T, waitMs: number,\n    nowFunc?: Function): (...args: Array<{}>) => T {\n  let lastTime = nowFunc != null ? nowFunc() : util.now();\n  let lastResult: T;\n  const f2 = (...args: Array<{}>) => {\n    const now = nowFunc != null ? nowFunc() : util.now();\n    if (now - lastTime < waitMs) {\n      return lastResult;\n    }\n    lastTime = now;\n    lastResult = f(...args);\n    return lastResult;\n  };\n  return f2;\n}\n\n/**\n * Returns the fusable activation given a layers identifier.\n *\n * @param activationName The layers identifier string.\n * @return The name of the fusable activation.\n */\nexport function mapActivationToFusedKernel(activationName: string):\n    fused.Activation {\n  if (activationName === 'relu') {\n    return 'relu';\n  }\n  if (activationName === 'linear') {\n    return 'linear';\n  }\n  if (activationName === 'elu') {\n    return 'elu';\n  }\n  return null;\n}\n\ntype PossibleValues = Array<Array<boolean|string|number>>;\n\n/**\n * Returns the cartesian product of sets of values.\n * This works the same as itertools.product in Python.\n *\n * Example:\n *\n * filters = [128, 256, 512]\n * paddings = ['same', 'valid']\n *\n * product = [ [128, 'same'], [128, 'valid'], [256, 'same'], [256, 'valid'],\n * [512, 'same'], [512, 'valid']]\n *\n * @param arrayOfValues List/array of values.\n * @return The cartesian product.\n */\nexport function getCartesianProductOfValues(...arrayOfValues: PossibleValues):\n    PossibleValues {\n  assert(arrayOfValues.length > 0, 'arrayOfValues is empty');\n\n  for (const values of arrayOfValues) {\n    assert(Array.isArray(values), 'one of the values is not an array');\n    assert(values.length > 0, 'one of the values is empty');\n  }\n\n  return arrayOfValues.reduce((products, values) => {\n    if (products.length === 0) {\n      return values.map(value => [value]);\n    }\n\n    return values\n        .map(value => {\n          return products.map((prevValue) => [...prevValue, value]);\n        })\n        .reduce((flattenedProduct, unflattenedProduct) => {\n          return flattenedProduct.concat(unflattenedProduct);\n        }, []);\n  }, [] as PossibleValues);\n}\n"]}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"generic_utils.js","sourceRoot":"","sources":["../../../../../../tfjs-layers/src/utils/generic_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,6CAA6C;AAE7C,OAAO,EAAiC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAE3E,OAAO,EAAC,cAAc,EAAE,UAAU,EAAC,MAAM,WAAW,CAAC;AAErD,gBAAgB;AAEhB;;;GAGG;AACH,kCAAkC;AAClC,MAAM,UAAU,YAAY,CAAC,KAAU,EAAE,SAAiB;IACxD,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QACxB,kCAAkC;QAClC,IAAI,QAAQ,GAAU,EAAE,CAAC;QACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE;YAClC,QAAQ,GAAG,QAAQ,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;SACnC;QACD,OAAO,QAAQ,CAAC;KACjB;SAAM;QACL,MAAM,QAAQ,GAAG,IAAI,KAAK,CAAC,SAAS,CAAC,CAAC;QACtC,QAAQ,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACrB,OAAO,QAAQ,CAAC;KACjB;AACH,CAAC;AAED,MAAM,UAAU,MAAM,CAAC,GAAY,EAAE,OAAgB;IACnD,IAAI,CAAC,GAAG,EAAE;QACR,MAAM,IAAI,cAAc,CAAC,OAAO,CAAC,CAAC;KACnC;AACH,CAAC;AAED;;GAEG;AACH,MAAM,UAAU,KAAK,CAAI,KAAU,EAAE,QAAW;IAC9C,IAAI,OAAO,GAAG,CAAC,CAAC;IAChB,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;QACxB,IAAI,IAAI,KAAK,QAAQ,EAAE;YACrB,OAAO,EAAE,CAAC;SACX;KACF;IACD,OAAO,OAAO,CAAC;AACjB,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,gBAAgB,CAAI,EAAO;IACzC,IAAI,EAAE,CAAC,MAAM,KAAK,CAAC,EAAE;QACnB,OAAO,EAAE,CAAC,CAAC,CAAC,CAAC;KACd;IACD,OAAO,EAAE,CAAC;AACZ,CAAC;AAED;;;;;;;GAOG;AACH,kCAAkC;AAClC,MAAM,UAAU,MAAM,CAAI,CAAQ;IAChC,IAAI,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;QACpB,OAAO,CAAC,CAAC;KACV;IACD,OAAO,CAAC,CAAC,CAAC,CAAC;AACb,CAAC;AAED;;GAEG;AACH,kCAAkC;AAClC,MAAM,UAAU,aAAa,CAAC,IAAe;IAC3C,MAAM,UAAU,GAAG,MAAM,CAAC,IAAI,CAAC,CAAC;IAChC,IAAI,MAAM,GAAG,EAAE,CAAC;IAChB,KAAK,MAAM,GAAG,IAAI,UAAU,EAAE;QAC5B,IAAI,GAAG,CAAC,EAAE,IAAI,IAAI,EAAE;YAClB,MAAM,IAAI,UAAU,CAChB,UAAU,GAAG,wCAAwC,CAAC,CAAC;SAC5D;QACD,IAAI,MAAM,KAAK,EAAE,EAAE;YACjB,MAAM,GAAG,MAAM,GAAG,IAAI,CAAC;SACxB;QACD,MAAM,GAAG,GAAG,MAAM,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC;KACzC;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AACD;;;GAGG;AACH,MAAM,UAAU,WAAW,CAAC,IAAY;IACtC,MAAM,YAAY,GAAG,IAAI,CAAC,OAAO,CAAC,sBAAsB,EAAE,OAAO,CAAC,CAAC;IACnE,MAAM,QAAQ,GACV,YAAY,CAAC,OAAO,CAAC,iBAAiB,EAAE,OAAO,CAAC,CAAC,WAAW,EAAE,CAAC;IACnE;;;OAGG;IACH,IAAI,QAAQ,CAAC,CAAC,CAAC,KAAK,GAAG,EAAE;QACvB,OAAO,QAAQ,CAAC;KACjB;IACD,OAAO,SAAS,GAAG,QAAQ,CAAC;AAC9B,CAAC;AAED,MAAM,UAAU,WAAW,CAAC,UAAkB;IAC5C,4DAA4D;IAC5D,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE;QAC1B,OAAO,UAAU,CAAC;KACnB;IACD,iDAAiD;IACjD,IAAI,UAAU,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,EAAE;QAClC,OAAO,UAAU,CAAC;KACnB;IACD,OAAO,UAAU,CAAC,OAAO,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,EAAE,CAAC,WAAW,EAAE,CAAC,CAAC;AACxE,CAAC;AAED,kCAAkC;AAClC,IAAI,sBAAsB,GAAG,EAA8B,CAAC;AAE5D,MAAM,UAAU,oBAAoB,CAAC,QAAoC;IAEvE,IAAI,QAAQ,KAAK,IAAI,IAAI,QAAQ,KAAK,SAAS,EAAE;QAC/C,OAAO,IAAI,CAAC;KACb;IACD,MAAM,IAAI,GAAkC,EAAE,CAAC;IAC/C,IAAI,CAAC,WAAW,CAAC,GAAG,QAAQ,CAAC,YAAY,EAAE,CAAC;IAC5C,IAAI,CAAC,QAAQ,CAAC,GAAG,QAAQ,CAAC,SAAS,EAAE,CAAC;IACtC,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;;;;;;;GAUG;AACH,SAAS,6BAA6B,CAAC,MAAqC;IAE1E,IAAI,MAAM,IAAI,IAAI,IAAI,OAAO,MAAM,KAAK,QAAQ,EAAE;QAChD,OAAO;KACR;SAAM,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;QAChC,MAAM,CAAC,OAAO,CAAC,UAAU,CAAC,EAAE,CAAC,6BAA6B,CAAC,UAAU,CAAC,CAAC,CAAC;KACzE;SAAM;QACL,MAAM,MAAM,GAAG,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;QACnC,KAAK,MAAM,KAAK,IAAI,MAAM,EAAE;YAC1B,MAAM,KAAK,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC;YAC5B,IAAI,KAAK,IAAI,IAAI,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;gBAC9C,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,IAAI,KAAK,CAAC,MAAM,CAAC,KAAK,SAAS;oBACpD,OAAO,KAAK,CAAC,OAAO,CAAC,KAAK,QAAQ,EAAE;oBACtC,MAAM,CAAC,KAAK,CAAC,GAAG,KAAK,CAAC,OAAO,CAAC,CAAC;iBAChC;qBAAM;oBACL,6BAA6B,CAAC,KAAiC,CAAC,CAAC;iBAClE;aACF;SACF;KACF;AACH,CAAC;AAED;;;;;;;;;;;GAWG;AACH,wBAAwB;AACxB,MAAM,UAAU,sBAAsB,CAClC,UAA2C,EAC3C,gBAAgB,EAA8B,EAC9C,gBAAgB,EAA8B,EAC9C,mBAAmB,GAAG,QAAQ,EAAE,cAAc,GAAG,KAAK;IACxD,gBAAgB;IAChB,IAAI,OAAO,UAAU,KAAK,QAAQ,EAAE;QAClC,MAAM,YAAY,GAAG,UAAU,CAAC;QAChC,IAAI,EAAE,CAAC;QACP,IAAI,YAAY,IAAI,aAAa,EAAE;YACjC,EAAE,GAAG,aAAa,CAAC,YAAY,CAAC,CAAC;SAClC;aAAM,IAAI,YAAY,IAAI,sBAAsB,EAAE;YACjD,EAAE,GAAG,sBAAsB,CAAC,YAAY,CAAC,CAAC;SAC3C;aAAM;YACL,EAAE,GAAG,aAAa,CAAC,YAAY,CAAC,CAAC;YACjC,IAAI,EAAE,IAAI,IAAI,EAAE;gBACd,MAAM,IAAI,UAAU,CAChB,WAAW,mBAAmB,KAAK,UAAU,IAAI;oBACjD,oDAAoD;oBACpD,UAAU,mBAAmB,kCAAkC;oBAC/D,iEAAiE;oBACjE,SAAS;oBACT,iBAAiB,mBAAmB,6BAA6B;oBACjE,sCAAsC;oBACtC,mCAAmC,CAAC,CAAC;gBACzC,0DAA0D;aAC3D;SACF;QACD,OAAO,EAAE,CAAC;KACX;SAAM;QACL,8DAA8D;QAC9D,MAAM,MAAM,GAAG,UAAU,CAAC;QAC1B,IAAI,MAAM,CAAC,WAAW,CAAC,IAAI,IAAI,IAAI,MAAM,CAAC,QAAQ,CAAC,IAAI,IAAI,EAAE;YAC3D,MAAM,IAAI,UAAU,CAChB,GAAG,mBAAmB,4BAA4B;gBAClD,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK;gBAC9B,oCAAoC,CAAC,CAAC;SAC3C;QACD,MAAM,SAAS,GAAG,MAAM,CAAC,WAAW,CAAW,CAAC;QAChD,IAAI,GAAG,EAAE,UAAU,CAAC;QACpB,IAAI,SAAS,IAAI,aAAa,EAAE;YAC9B,CAAC,GAAG,EAAE,UAAU,CAAC,GAAG,aAAa,CAAC,SAAS,CAAC,CAAC;SAC9C;aAAM,IAAI,SAAS,IAAI,sBAAsB,EAAE;YAC9C,CAAC,GAAG,EAAE,UAAU,CAAC,GAAG,sBAAsB,CAAC,WAAW,CAAC,CAAC;SACzD;aAAM,IAAI,SAAS,IAAI,aAAa,EAAE;YACrC,CAAC,GAAG,EAAE,UAAU,CAAC,GAAG,aAAa,CAAC,SAAS,CAAC,CAAC;SAC9C;QACD,IAAI,GAAG,IAAI,IAAI,EAAE;YACf,MAAM,IAAI,UAAU,CAChB,WAAW,mBAAmB,KAAK,SAAS,IAAI;gBAChD,oDAAoD;gBACpD,UAAU,mBAAmB,kCAAkC;gBAC/D,iEAAiE;gBACjE,SAAS;gBACT,iBAAiB,mBAAmB,6BAA6B;gBACjE,sCAAsC;gBACtC,mCAAmC,CAAC,CAAC;YACzC,0DAA0D;SAC3D;QACD,IAAI,UAAU,IAAI,IAAI,EAAE;YACtB,uEAAuE;YACvE,wEAAwE;YACxE,0EAA0E;YAC1E,gBAAgB;YAEhB,kCAAkC;YAClC,MAAM,qBAAqB,GAAG,EAA8B,CAAC;YAC7D,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,sBAAsB,CAAC,EAAE;gBACrD,qBAAqB,CAAC,GAAG,CAAC,GAAG,sBAAsB,CAAC,GAAG,CAAC,CAAC;aAC1D;YACD,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBAC5C,qBAAqB,CAAC,GAAG,CAAC,GAAG,aAAa,CAAC,GAAG,CAAC,CAAC;aACjD;YACD,kCAAkC;YAClC,MAAM,YAAY,GAAG,MAAM,CAAC,QAAQ,CAA6B,CAAC;YAClE,YAAY,CAAC,eAAe,CAAC,GAAG,qBAAqB,CAAC;YAEtD,MAAM,mBAAmB,qBAAO,sBAAsB,CAAC,CAAC;YACxD,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBAC5C,sBAAsB,CAAC,GAAG,CAAC,GAAG,aAAa,CAAC,GAAG,CAAC,CAAC;aAClD;YACD,6BAA6B,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC;YAChD,MAAM,SAAS,GACX,UAAU,CAAC,GAAG,EAAE,MAAM,CAAC,QAAQ,CAAC,EAAE,aAAa,EAAE,cAAc,CAAC,CAAC;YACrE,sBAAsB,qBAAO,mBAAmB,CAAC,CAAC;YAElD,OAAO,SAAS,CAAC;SAClB;aAAM;YACL,kDAAkD;YAClD,4CAA4C;YAC5C,8BAA8B;YAC9B,MAAM,mBAAmB,qBAAO,sBAAsB,CAAC,CAAC;YACxD,KAAK,MAAM,GAAG,IAAI,MAAM,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;gBAC5C,sBAAsB,CAAC,GAAG,CAAC,GAAG,aAAa,CAAC,GAAG,CAAC,CAAC;aAClD;YACD,mEAAmE;YACnE,iEAAiE;YACjE,oEAAoE;YACpE,MAAM,SAAS,GAAG,IAAI,GAAG,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC;YAC5C,sBAAsB,qBAAO,mBAAmB,CAAC,CAAC;YAClD,OAAO,SAAS,CAAC;SAClB;KACF;AACH,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,aAAa,CAAC,CAAS,EAAE,CAAS;IAChD,OAAO,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;AAC1C,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,oBAAoB,CAAC,CAAS,EAAE,CAAS;IACvD,OAAO,CAAC,CAAC,GAAG,aAAa,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;AAClC,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,aAAa,CAAC,KAAa;IACzC,QAAQ,KAAK,EAAE;QACb,KAAK,SAAS;YACZ,OAAO,SAAS,CAAC;QACnB;YACE,MAAM,IAAI,UAAU,CAAC,kBAAkB,KAAK,EAAE,CAAC,CAAC;KACnD;AACH,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,YAAY,CAAC,EAAY,EAAE,EAAY;IACrD,IAAI,EAAE,IAAI,IAAI,IAAI,EAAE,IAAI,IAAI,EAAE;QAC5B,OAAO,EAAE,KAAK,EAAE,CAAC;KAClB;IACD,IAAI,EAAE,CAAC,MAAM,KAAK,EAAE,CAAC,MAAM,EAAE;QAC3B,OAAO,KAAK,CAAC;KACd;IACD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QAClC,IAAI,EAAE,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,EAAE;YACnB,OAAO,KAAK,CAAC;SACd;KACF;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,MAAM,CAAI,EAAO;IAC/B,IAAI,EAAE,IAAI,IAAI,EAAE;QACd,OAAO,EAAE,CAAC;KACX;IACD,MAAM,GAAG,GAAQ,EAAE,CAAC;IACpB,oDAAoD;IACpD,KAAK,MAAM,CAAC,IAAI,EAAE,EAAE;QAClB,IAAI,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE;YACzB,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACb;KACF;IACD,OAAO,GAAG,CAAC;AACb,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,aAAa,CAAC,GAAO;IACnC,IAAI,GAAG,IAAI,IAAI,EAAE;QACf,MAAM,IAAI,UAAU,CAAC,yBAAyB,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;KACtE;IACD,KAAK,MAAM,GAAG,IAAI,GAAG,EAAE;QACrB,IAAI,GAAG,CAAC,cAAc,CAAC,GAAG,CAAC,EAAE;YAC3B,OAAO,KAAK,CAAC;SACd;KACF;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,yBAAyB,CACrC,MAAgB,EAAE,KAAa,EAAE,KAAa;IAChD,IAAI,KAAK,IAAI,IAAI,EAAE;QACjB,OAAO;KACR;IACD,IAAI,MAAM,CAAC,OAAO,CAAC,KAAK,CAAC,GAAG,CAAC,EAAE;QAC7B,MAAM,IAAI,UAAU,CAAC,GAAG,KAAK,mBAAmB,KAAK,uBACjD,MAAM,qBAAqB,CAAC,CAAC;KAClC;AACH,CAAC;AAED;;;;;;;;;;;;;GAaG;AACH,wBAAwB;AACxB,MAAM,UAAU,uBAAuB,CACnC,CAAM,EAAE,YAAoB,EAAE,SAAS,GAAG,CAAC,EAC3C,SAAS,GAAG,QAAQ;IACtB,MAAM,CAAC,SAAS,IAAI,CAAC,CAAC,CAAC;IACvB,MAAM,CAAC,SAAS,IAAI,SAAS,CAAC,CAAC;IAC/B,OAAO,CACH,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,MAAM,IAAI,SAAS,IAAI,CAAC,CAAC,MAAM,IAAI,SAAS;QAClE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,KAAK,YAAY,CAAC,CAAC,CAAC;AAC/C,CAAC;AACD,uBAAuB;AAEvB;;;;;;GAMG;AACH,MAAM,UAAU,qBAAqB,CAAC,KAAsB,EAAE,IAAY;IACxE,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QACxB,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,GAAG,EAAE,CAAC,GAAG,IAAI,kCAAkC,CAAC,CAAC;QACvE,KAAK,CAAC,OAAO,CACT,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,qBAAqB,CAAC,CAAC,EAAE,WAAW,CAAC,GAAG,CAAC,OAAO,IAAI,EAAE,CAAC,CAAC,CAAC;KACxE;SAAM;QACL,IAAI,CAAC,MAAM,CACP,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,IAAI,KAAK,GAAG,CAAC,EACpC,GAAG,EAAE,CAAC,YAAY,IAAI,qCAAqC;YACvD,GAAG,sBAAsB,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC;KAC9C;AACH,CAAC;AAED;;;;;;;;;GASG;AACH,kCAAkC;AAClC,MAAM,UAAU,sBAAsB,CAAC,KAAU;IAC/C,IAAI,KAAK,KAAK,IAAI,EAAE;QAClB,OAAO,MAAM,CAAC;KACf;SAAM,IAAI,KAAK,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE;QAC/B,OAAO,GAAG,GAAG,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,sBAAsB,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,GAAG,GAAG,CAAC;KACxE;SAAM,IAAI,OAAO,KAAK,KAAK,QAAQ,EAAE;QACpC,OAAO,IAAI,KAAK,GAAG,CAAC;KACrB;SAAM;QACL,OAAO,GAAG,KAAK,EAAE,CAAC;KACnB;AACH,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,QAAQ,CACpB,CAA4B,EAAE,MAAc,EAC5C,OAAkB;IACpB,IAAI,QAAQ,GAAG,OAAO,IAAI,IAAI,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC;IACxD,IAAI,UAAa,CAAC;IAClB,MAAM,EAAE,GAAG,CAAC,GAAG,IAAe,EAAE,EAAE;QAChC,MAAM,GAAG,GAAG,OAAO,IAAI,IAAI,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC;QACrD,IAAI,GAAG,GAAG,QAAQ,GAAG,MAAM,EAAE;YAC3B,OAAO,UAAU,CAAC;SACnB;QACD,QAAQ,GAAG,GAAG,CAAC;QACf,UAAU,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;QACxB,OAAO,UAAU,CAAC;IACpB,CAAC,CAAC;IACF,OAAO,EAAE,CAAC;AACZ,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,0BAA0B,CAAC,cAAsB;IAE/D,IAAI,cAAc,KAAK,MAAM,EAAE;QAC7B,OAAO,MAAM,CAAC;KACf;IACD,IAAI,cAAc,KAAK,QAAQ,EAAE;QAC/B,OAAO,QAAQ,CAAC;KACjB;IACD,IAAI,cAAc,KAAK,KAAK,EAAE;QAC5B,OAAO,KAAK,CAAC;KACd;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAID;;;;;;;;;;;;;;GAcG;AACH,MAAM,UAAU,2BAA2B,CAAC,GAAG,aAA6B;IAE1E,MAAM,CAAC,aAAa,CAAC,MAAM,GAAG,CAAC,EAAE,wBAAwB,CAAC,CAAC;IAE3D,KAAK,MAAM,MAAM,IAAI,aAAa,EAAE;QAClC,MAAM,CAAC,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,mCAAmC,CAAC,CAAC;QACnE,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,EAAE,4BAA4B,CAAC,CAAC;KACzD;IAED,OAAO,aAAa,CAAC,MAAM,CAAC,CAAC,QAAQ,EAAE,MAAM,EAAE,EAAE;QAC/C,IAAI,QAAQ,CAAC,MAAM,KAAK,CAAC,EAAE;YACzB,OAAO,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC;SACrC;QAED,OAAO,MAAM;aACR,GAAG,CAAC,KAAK,CAAC,EAAE;YACX,OAAO,QAAQ,CAAC,GAAG,CAAC,CAAC,SAAS,EAAE,EAAE,CAAC,CAAC,GAAG,SAAS,EAAE,KAAK,CAAC,CAAC,CAAC;QAC5D,CAAC,CAAC;aACD,MAAM,CAAC,CAAC,gBAAgB,EAAE,kBAAkB,EAAE,EAAE;YAC/C,OAAO,gBAAgB,CAAC,MAAM,CAAC,kBAAkB,CAAC,CAAC;QACrD,CAAC,EAAE,EAAE,CAAC,CAAC;IACb,CAAC,EAAE,EAAoB,CAAC,CAAC;AAC3B,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n/* Original source: utils/generic_utils.py */\n\nimport {DataType, fused, serialization, util} from '@tensorflow/tfjs-core';\n\nimport {AssertionError, ValueError} from '../errors';\n\n// tslint:enable\n\n/**\n * If `value` is an Array, equivalent to Python's `value * numValues`.\n * If `value` is not an Array, equivalent to Python's `[value] * numValues`\n */\n// tslint:disable-next-line:no-any\nexport function pyListRepeat(value: any, numValues: number): any[] {\n  if (Array.isArray(value)) {\n    // tslint:disable-next-line:no-any\n    let newArray: any[] = [];\n    for (let i = 0; i < numValues; i++) {\n      newArray = newArray.concat(value);\n    }\n    return newArray;\n  } else {\n    const newArray = new Array(numValues);\n    newArray.fill(value);\n    return newArray;\n  }\n}\n\nexport function assert(val: boolean, message?: string): void {\n  if (!val) {\n    throw new AssertionError(message);\n  }\n}\n\n/**\n * Count the number of elements of the `array` that are equal to `reference`.\n */\nexport function count<T>(array: T[], refernce: T) {\n  let counter = 0;\n  for (const item of array) {\n    if (item === refernce) {\n      counter++;\n    }\n  }\n  return counter;\n}\n\n/**\n * If an array is of length 1, just return the first element. Otherwise, return\n * the full array.\n * @param tensors\n */\nexport function singletonOrArray<T>(xs: T[]): T|T[] {\n  if (xs.length === 1) {\n    return xs[0];\n  }\n  return xs;\n}\n\n/**\n * Normalizes a list/tensor into a list.\n *\n * If a tensor is passed, we return\n * a list of size 1 containing the tensor.\n *\n * @param x target object to be normalized.\n */\n// tslint:disable-next-line:no-any\nexport function toList<T>(x: T|T[]): T[] {\n  if (Array.isArray(x)) {\n    return x;\n  }\n  return [x];\n}\n\n/**\n * Generate a UID for a list\n */\n// tslint:disable-next-line:no-any\nexport function objectListUid(objs: any|any[]): string {\n  const objectList = toList(objs);\n  let retVal = '';\n  for (const obj of objectList) {\n    if (obj.id == null) {\n      throw new ValueError(\n          `Object ${obj} passed to objectListUid without an id`);\n    }\n    if (retVal !== '') {\n      retVal = retVal + ', ';\n    }\n    retVal = `${retVal}${Math.abs(obj.id)}`;\n  }\n  return retVal;\n}\n/**\n * Converts string to snake-case.\n * @param name\n */\nexport function toSnakeCase(name: string): string {\n  const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');\n  const insecure =\n      intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();\n  /*\n   If the class is private the name starts with \"_\" which is not secure\n   for creating scopes. We prefix the name with \"private\" in this case.\n   */\n  if (insecure[0] !== '_') {\n    return insecure;\n  }\n  return 'private' + insecure;\n}\n\nexport function toCamelCase(identifier: string): string {\n  // quick return for empty string or single character strings\n  if (identifier.length <= 1) {\n    return identifier;\n  }\n  // Check for the underscore indicating snake_case\n  if (identifier.indexOf('_') === -1) {\n    return identifier;\n  }\n  return identifier.replace(/[_]+(\\w|$)/g, (m, p1) => p1.toUpperCase());\n}\n\n// tslint:disable-next-line:no-any\nlet _GLOBAL_CUSTOM_OBJECTS = {} as {[objName: string]: any};\n\nexport function serializeKerasObject(instance: serialization.Serializable):\n    serialization.ConfigDictValue {\n  if (instance === null || instance === undefined) {\n    return null;\n  }\n  const dict: serialization.ConfigDictValue = {};\n  dict['className'] = instance.getClassName();\n  dict['config'] = instance.getConfig();\n  return dict;\n}\n\n/**\n * Replace ndarray-style scalar objects in serialization objects with numbers.\n *\n * Background: In some versions of tf.keras, certain scalar values in the HDF5\n * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,\n * where in `num` is a plain number. This method converts such serialization\n * to a `number`.\n *\n * @param config The keras-format serialization object to be processed\n *   (in place).\n */\nfunction convertNDArrayScalarsInConfig(config: serialization.ConfigDictValue):\n    void {\n  if (config == null || typeof config !== 'object') {\n    return;\n  } else if (Array.isArray(config)) {\n    config.forEach(configItem => convertNDArrayScalarsInConfig(configItem));\n  } else {\n    const fields = Object.keys(config);\n    for (const field of fields) {\n      const value = config[field];\n      if (value != null && typeof value === 'object') {\n        if (!Array.isArray(value) && value['type'] === 'ndarray' &&\n            typeof value['value'] === 'number') {\n          config[field] = value['value'];\n        } else {\n          convertNDArrayScalarsInConfig(value as serialization.ConfigDict);\n        }\n      }\n    }\n  }\n}\n\n/**\n * Deserialize a saved Keras Object\n * @param identifier either a string ID or a saved Keras dictionary\n * @param moduleObjects a list of Python class names to object constructors\n * @param customObjects a list of Python class names to object constructors\n * @param printableModuleName debug text for the object being reconstituted\n * @param fastWeightInit Optional flag to use fast weight initialization\n *   during deserialization. This is applicable to cases in which\n *   the initialization will be immediately overwritten by loaded weight\n *   values. Default: `false`.\n * @returns a TensorFlow.js Layers object\n */\n// tslint:disable:no-any\nexport function deserializeKerasObject(\n    identifier: string|serialization.ConfigDict,\n    moduleObjects = {} as {[objName: string]: any},\n    customObjects = {} as {[objName: string]: any},\n    printableModuleName = 'object', fastWeightInit = false): any {\n  // tslint:enable\n  if (typeof identifier === 'string') {\n    const functionName = identifier;\n    let fn;\n    if (functionName in customObjects) {\n      fn = customObjects[functionName];\n    } else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {\n      fn = _GLOBAL_CUSTOM_OBJECTS[functionName];\n    } else {\n      fn = moduleObjects[functionName];\n      if (fn == null) {\n        throw new ValueError(\n            `Unknown ${printableModuleName}: ${identifier}. ` +\n            `This may be due to one of the following reasons:\\n` +\n            `1. The ${printableModuleName} is defined in Python, in which ` +\n            `case it needs to be ported to TensorFlow.js or your JavaScript ` +\n            `code.\\n` +\n            `2. The custom ${printableModuleName} is defined in JavaScript, ` +\n            `but is not registered properly with ` +\n            `tf.serialization.registerClass().`);\n        // TODO(cais): Add link to tutorial page on custom layers.\n      }\n    }\n    return fn;\n  } else {\n    // In this case we are dealing with a Keras config dictionary.\n    const config = identifier;\n    if (config['className'] == null || config['config'] == null) {\n      throw new ValueError(\n          `${printableModuleName}: Improper config format: ` +\n          `${JSON.stringify(config)}.\\n` +\n          `'className' and 'config' must set.`);\n    }\n    const className = config['className'] as string;\n    let cls, fromConfig;\n    if (className in customObjects) {\n      [cls, fromConfig] = customObjects[className];\n    } else if (className in _GLOBAL_CUSTOM_OBJECTS) {\n      [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS['className'];\n    } else if (className in moduleObjects) {\n      [cls, fromConfig] = moduleObjects[className];\n    }\n    if (cls == null) {\n      throw new ValueError(\n          `Unknown ${printableModuleName}: ${className}. ` +\n          `This may be due to one of the following reasons:\\n` +\n          `1. The ${printableModuleName} is defined in Python, in which ` +\n          `case it needs to be ported to TensorFlow.js or your JavaScript ` +\n          `code.\\n` +\n          `2. The custom ${printableModuleName} is defined in JavaScript, ` +\n          `but is not registered properly with ` +\n          `tf.serialization.registerClass().`);\n      // TODO(cais): Add link to tutorial page on custom layers.\n    }\n    if (fromConfig != null) {\n      // Porting notes: Instead of checking to see whether fromConfig accepts\n      // customObjects, we create a customObjects dictionary and tack it on to\n      // config['config'] as config['config'].customObjects. Objects can use it,\n      // if they want.\n\n      // tslint:disable-next-line:no-any\n      const customObjectsCombined = {} as {[objName: string]: any};\n      for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {\n        customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];\n      }\n      for (const key of Object.keys(customObjects)) {\n        customObjectsCombined[key] = customObjects[key];\n      }\n      // Add the customObjects to config\n      const nestedConfig = config['config'] as serialization.ConfigDict;\n      nestedConfig['customObjects'] = customObjectsCombined;\n\n      const backupCustomObjects = {..._GLOBAL_CUSTOM_OBJECTS};\n      for (const key of Object.keys(customObjects)) {\n        _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];\n      }\n      convertNDArrayScalarsInConfig(config['config']);\n      const returnObj =\n          fromConfig(cls, config['config'], customObjects, fastWeightInit);\n      _GLOBAL_CUSTOM_OBJECTS = {...backupCustomObjects};\n\n      return returnObj;\n    } else {\n      // Then `cls` may be a function returning a class.\n      // In this case by convention `config` holds\n      // the kwargs of the function.\n      const backupCustomObjects = {..._GLOBAL_CUSTOM_OBJECTS};\n      for (const key of Object.keys(customObjects)) {\n        _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];\n      }\n      // In python this is **config['config'], for tfjs-layers we require\n      // classes that use this fall-through construction method to take\n      // a config interface that mimics the expansion of named parameters.\n      const returnObj = new cls(config['config']);\n      _GLOBAL_CUSTOM_OBJECTS = {...backupCustomObjects};\n      return returnObj;\n    }\n  }\n}\n\n/**\n * Compares two numbers for sorting.\n * @param a\n * @param b\n */\nexport function numberCompare(a: number, b: number) {\n  return (a < b) ? -1 : ((a > b) ? 1 : 0);\n}\n\n/**\n * Comparison of two numbers for reverse sorting.\n * @param a\n * @param b\n */\nexport function reverseNumberCompare(a: number, b: number) {\n  return -1 * numberCompare(a, b);\n}\n\n/**\n * Convert a string into the corresponding DType.\n * @param dtype\n * @returns An instance of DType.\n */\nexport function stringToDType(dtype: string): DataType {\n  switch (dtype) {\n    case 'float32':\n      return 'float32';\n    default:\n      throw new ValueError(`Invalid dtype: ${dtype}`);\n  }\n}\n\n/**\n * Test the element-by-element equality of two Arrays of strings.\n * @param xs First array of strings.\n * @param ys Second array of strings.\n * @returns Wether the two arrays are all equal, element by element.\n */\nexport function stringsEqual(xs: string[], ys: string[]): boolean {\n  if (xs == null || ys == null) {\n    return xs === ys;\n  }\n  if (xs.length !== ys.length) {\n    return false;\n  }\n  for (let i = 0; i < xs.length; ++i) {\n    if (xs[i] !== ys[i]) {\n      return false;\n    }\n  }\n  return true;\n}\n\n/**\n * Get the unique elements of an array.\n * @param xs Array.\n * @returns An Array consisting of the unique elements in `xs`.\n */\nexport function unique<T>(xs: T[]): T[] {\n  if (xs == null) {\n    return xs;\n  }\n  const out: T[] = [];\n  // TODO(cais): Maybe improve performance by sorting.\n  for (const x of xs) {\n    if (out.indexOf(x) === -1) {\n      out.push(x);\n    }\n  }\n  return out;\n}\n\n/**\n * Determine if an Object is empty (i.e., does not have own properties).\n * @param obj Object\n * @returns Whether the Object is empty.\n * @throws ValueError: If object is `null` or `undefined`.\n */\nexport function isObjectEmpty(obj: {}): boolean {\n  if (obj == null) {\n    throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);\n  }\n  for (const key in obj) {\n    if (obj.hasOwnProperty(key)) {\n      return false;\n    }\n  }\n  return true;\n}\n\n/**\n * Helper function used to build type union/enum run-time checkers.\n * @param values The list of allowed values.\n * @param label A string name for the type\n * @param value The value to test.\n * @throws ValueError: If the value is not in values nor `undefined`/`null`.\n */\nexport function checkStringTypeUnionValue(\n    values: string[], label: string, value: string): void {\n  if (value == null) {\n    return;\n  }\n  if (values.indexOf(value) < 0) {\n    throw new ValueError(`${value} is not a valid ${label}.  Valid values are ${\n        values} or null/undefined.`);\n  }\n}\n\n/**\n * Helper function for verifying the types of inputs.\n *\n * Ensures that the elements of `x` are all of type `expectedType`.\n * Also verifies that the length of `x` is within bounds.\n *\n * @param x Object to test.\n * @param expectedType The string expected type of all of the elements in the\n * Array.\n * @param minLength Return false if x.length is less than this.\n * @param maxLength Return false if x.length is greater than this.\n * @returns true if and only if `x` is an `Array<expectedType>` with\n * length >= `minLength` and <= `maxLength`.\n */\n// tslint:disable:no-any\nexport function checkArrayTypeAndLength(\n    x: any, expectedType: string, minLength = 0,\n    maxLength = Infinity): boolean {\n  assert(minLength >= 0);\n  assert(maxLength >= minLength);\n  return (\n      Array.isArray(x) && x.length >= minLength && x.length <= maxLength &&\n      x.every(e => typeof e === expectedType));\n}\n// tslint:enable:no-any\n\n/**\n * Assert that a value or an array of value are positive integer.\n *\n * @param value The value being asserted on. May be a single number or an array\n *   of numbers.\n * @param name Name of the value, used to make the error message.\n */\nexport function assertPositiveInteger(value: number|number[], name: string) {\n  if (Array.isArray(value)) {\n    util.assert(\n        value.length > 0, () => `${name} is unexpectedly an empty array.`);\n    value.forEach(\n        (v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));\n  } else {\n    util.assert(\n        Number.isInteger(value) && value > 0,\n        () => `Expected ${name} to be a positive integer, but got ` +\n            `${formatAsFriendlyString(value)}.`);\n  }\n}\n\n/**\n * Format a value into a display-friendly, human-readable fashion.\n *\n * - `null` is formatted as `'null'`\n * - Strings are formated with flanking pair of quotes.\n * - Arrays are formatted with flanking pair of square brackets.\n *\n * @param value The value to display.\n * @return Formatted string.\n */\n// tslint:disable-next-line:no-any\nexport function formatAsFriendlyString(value: any): string {\n  if (value === null) {\n    return 'null';\n  } else if (Array.isArray(value)) {\n    return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']';\n  } else if (typeof value === 'string') {\n    return `\"${value}\"`;\n  } else {\n    return `${value}`;\n  }\n}\n\n/**\n * Returns a function `f2` (decorator) which wraps the original function\n * `f`. `f2` guarantees that `f` can be called at most once\n * every `waitMs` ms. If `f2` is called more often, it will return\n * the last returned result of `f`.\n *\n * @param f The original function `f` to wrap.\n * @param waitMs The time between two consecutive calls to `f` in ms.\n */\nexport function debounce<T>(\n    f: (...args: Array<{}>) => T, waitMs: number,\n    nowFunc?: Function): (...args: Array<{}>) => T {\n  let lastTime = nowFunc != null ? nowFunc() : util.now();\n  let lastResult: T;\n  const f2 = (...args: Array<{}>) => {\n    const now = nowFunc != null ? nowFunc() : util.now();\n    if (now - lastTime < waitMs) {\n      return lastResult;\n    }\n    lastTime = now;\n    lastResult = f(...args);\n    return lastResult;\n  };\n  return f2;\n}\n\n/**\n * Returns the fusable activation given a layers identifier.\n *\n * @param activationName The layers identifier string.\n * @return The name of the fusable activation.\n */\nexport function mapActivationToFusedKernel(activationName: string):\n    fused.Activation {\n  if (activationName === 'relu') {\n    return 'relu';\n  }\n  if (activationName === 'linear') {\n    return 'linear';\n  }\n  if (activationName === 'elu') {\n    return 'elu';\n  }\n  return null;\n}\n\ntype PossibleValues = Array<Array<boolean|string|number>>;\n\n/**\n * Returns the cartesian product of sets of values.\n * This works the same as itertools.product in Python.\n *\n * Example:\n *\n * filters = [128, 256, 512]\n * paddings = ['same', 'valid']\n *\n * product = [ [128, 'same'], [128, 'valid'], [256, 'same'], [256, 'valid'],\n * [512, 'same'], [512, 'valid']]\n *\n * @param arrayOfValues List/array of values.\n * @return The cartesian product.\n */\nexport function getCartesianProductOfValues(...arrayOfValues: PossibleValues):\n    PossibleValues {\n  assert(arrayOfValues.length > 0, 'arrayOfValues is empty');\n\n  for (const values of arrayOfValues) {\n    assert(Array.isArray(values), 'one of the values is not an array');\n    assert(values.length > 0, 'one of the values is empty');\n  }\n\n  return arrayOfValues.reduce((products, values) => {\n    if (products.length === 0) {\n      return values.map(value => [value]);\n    }\n\n    return values\n        .map(value => {\n          return products.map((prevValue) => [...prevValue, value]);\n        })\n        .reduce((flattenedProduct, unflattenedProduct) => {\n          return flattenedProduct.concat(unflattenedProduct);\n        }, []);\n  }, [] as PossibleValues);\n}\n"]}
/** @license See the LICENSE file. */
/// <amd-module name="@tensorflow/tfjs-layers/dist/version" />
declare const version = "4.10.0";
declare const version = "4.11.0";
export { version };
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
const version = '4.10.0';
const version = '4.11.0';
export { version };
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidmVyc2lvbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtbGF5ZXJzL3NyYy92ZXJzaW9uLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBLHFDQUFxQztBQUVyQyx3REFBd0Q7QUFDeEQsTUFBTSxPQUFPLEdBQUcsUUFBUSxDQUFDO0FBQ3pCLE9BQU8sRUFBQyxPQUFPLEVBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKiBAbGljZW5zZSBTZWUgdGhlIExJQ0VOU0UgZmlsZS4gKi9cblxuLy8gVGhpcyBjb2RlIGlzIGF1dG8tZ2VuZXJhdGVkLCBkbyBub3QgbW9kaWZ5IHRoaXMgZmlsZSFcbmNvbnN0IHZlcnNpb24gPSAnNC4xMC4wJztcbmV4cG9ydCB7dmVyc2lvbn07XG4iXX0=
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidmVyc2lvbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtbGF5ZXJzL3NyYy92ZXJzaW9uLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBLHFDQUFxQztBQUVyQyx3REFBd0Q7QUFDeEQsTUFBTSxPQUFPLEdBQUcsUUFBUSxDQUFDO0FBQ3pCLE9BQU8sRUFBQyxPQUFPLEVBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKiBAbGljZW5zZSBTZWUgdGhlIExJQ0VOU0UgZmlsZS4gKi9cblxuLy8gVGhpcyBjb2RlIGlzIGF1dG8tZ2VuZXJhdGVkLCBkbyBub3QgbW9kaWZ5IHRoaXMgZmlsZSFcbmNvbnN0IHZlcnNpb24gPSAnNC4xMS4wJztcbmV4cG9ydCB7dmVyc2lvbn07XG4iXX0=
{
"name": "@tensorflow/tfjs-layers",
"version": "4.10.0",
"version": "4.11.0",
"description": "TensorFlow layers API in JavaScript",

@@ -41,4 +41,4 @@ "license": "Apache-2.0 AND MIT",

"peerDependencies": {
"@tensorflow/tfjs-core": "4.10.0"
"@tensorflow/tfjs-core": "4.11.0"
}
}

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

SocketSocket SOC 2 Logo

Product

  • Package Alerts
  • Integrations
  • Docs
  • Pricing
  • FAQ
  • Roadmap
  • Changelog

Packages

npm

Stay in touch

Get open source security insights delivered straight into your inbox.


  • Terms
  • Privacy
  • Security

Made with ⚡️ by Socket Inc