@tensorflow/tfjs-layers
Advanced tools
Comparing version 4.18.0 to 4.19.0-rc.0
@@ -155,7 +155,7 @@ /** | ||
/** | ||
* Swish activation function | ||
* Gelu activation function | ||
*/ | ||
export declare class Swish extends Activation { | ||
export declare class Gelu extends Activation { | ||
/** @nocollapse */ | ||
static readonly className = "swish"; | ||
static readonly className = "gelu"; | ||
/** | ||
@@ -165,8 +165,21 @@ * Calculate the activation function. | ||
* @param x Tensor. | ||
* @param alpha Scaling factor for the sigmoid function. | ||
* @returns a Tensor of the same shape as x | ||
*/ | ||
apply(x: Tensor, alpha?: number): Tensor; | ||
apply(x: Tensor): Tensor; | ||
} | ||
/** | ||
* GeluNew activation function | ||
*/ | ||
export declare class GeluNew extends Activation { | ||
/** @nocollapse */ | ||
static readonly className = "gelu_new"; | ||
/** | ||
* Calculate the activation function. | ||
* | ||
* @param x Tensor. | ||
* @returns a Tensor of the same shape as x | ||
*/ | ||
apply(x: Tensor): Tensor; | ||
} | ||
/** | ||
* Mish activation function | ||
@@ -185,4 +198,19 @@ */ | ||
} | ||
/** | ||
* Swish activation function | ||
*/ | ||
export declare class Swish extends Activation { | ||
/** @nocollapse */ | ||
static readonly className = "swish"; | ||
/** | ||
* Calculate the activation function. | ||
* | ||
* @param x Tensor. | ||
* @param alpha Scaling factor for the sigmoid function. | ||
* @returns a Tensor of the same shape as x | ||
*/ | ||
apply(x: Tensor, alpha?: number): Tensor; | ||
} | ||
export declare function serializeActivation(activation: Activation): string; | ||
export declare function deserializeActivation(config: serialization.ConfigDict, customObjects?: serialization.ConfigDict): Activation; | ||
export declare function getActivation(identifier: ActivationIdentifier | serialization.ConfigDict | Activation): Activation; |
@@ -207,5 +207,5 @@ /** | ||
/** | ||
* Swish activation function | ||
* Gelu activation function | ||
*/ | ||
class Swish extends Activation { | ||
class Gelu extends Activation { | ||
/** | ||
@@ -215,14 +215,41 @@ * Calculate the activation function. | ||
* @param x Tensor. | ||
* @param alpha Scaling factor for the sigmoid function. | ||
* @returns a Tensor of the same shape as x | ||
*/ | ||
apply(x, alpha = 1) { | ||
return tidy(() => tfc.mul(tfc.sigmoid(tfc.mul(x, alpha)), x)); | ||
apply(x) { | ||
return tidy(() => { | ||
return tfc.tidy(() => { | ||
const sqrtTwo = Math.sqrt(2); | ||
// Compute Φ(x) using the erf function | ||
const cdf = tfc.mul(0.5, tfc.add(1, tfc.erf(tfc.div(x, sqrtTwo)))); | ||
// Compute GELU(x) = x * Φ(x) | ||
return tfc.mul(x, cdf); | ||
}); | ||
}); | ||
} | ||
} | ||
/** @nocollapse */ | ||
Swish.className = 'swish'; | ||
export { Swish }; | ||
serialization.registerClass(Swish); | ||
Gelu.className = 'gelu'; | ||
export { Gelu }; | ||
serialization.registerClass(Gelu); | ||
/** | ||
* GeluNew activation function | ||
*/ | ||
class GeluNew extends Activation { | ||
/** | ||
* Calculate the activation function. | ||
* | ||
* @param x Tensor. | ||
* @returns a Tensor of the same shape as x | ||
*/ | ||
apply(x) { | ||
return tidy(() => { | ||
return tfc.mul(0.5, tfc.mul(x, tfc.add(1, tfc.tanh(tfc.mul(tfc.sqrt(tfc.div(2, Math.PI)), tfc.add(x, tfc.mul(0.044715, tfc.pow(x, 3)))))))); | ||
}); | ||
} | ||
} | ||
/** @nocollapse */ | ||
GeluNew.className = 'gelu_new'; | ||
export { GeluNew }; | ||
serialization.registerClass(GeluNew); | ||
/** | ||
* Mish activation function | ||
@@ -245,2 +272,21 @@ */ | ||
serialization.registerClass(Mish); | ||
/** | ||
* Swish activation function | ||
*/ | ||
class Swish extends Activation { | ||
/** | ||
* Calculate the activation function. | ||
* | ||
* @param x Tensor. | ||
* @param alpha Scaling factor for the sigmoid function. | ||
* @returns a Tensor of the same shape as x | ||
*/ | ||
apply(x, alpha = 1) { | ||
return tidy(() => tfc.mul(tfc.sigmoid(tfc.mul(x, alpha)), x)); | ||
} | ||
} | ||
/** @nocollapse */ | ||
Swish.className = 'swish'; | ||
export { Swish }; | ||
serialization.registerClass(Swish); | ||
export function serializeActivation(activation) { | ||
@@ -272,2 +318,2 @@ return activation.getClassName(); | ||
} | ||
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"activations.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/activations.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,6BAA6B;AAC7B,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,aAAa,EAAU,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAClE,OAAO,KAAK,CAAC,MAAM,wBAAwB,CAAC;AAE5C,OAAO,EAAC,sBAAsB,EAAC,MAAM,uBAAuB,CAAC;AAE7D;;;;;;GAMG;AACH,MAAM,OAAgB,UAAW,SAAQ,aAAa,CAAC,YAAY;IAEjE,SAAS;QACP,OAAO,EAAE,CAAC;IACZ,CAAC;CACF;AAED;;;GAGG;AACH,MAAa,GAAI,SAAQ,UAAU;IAGjC;;;;;;OAMG;IACH,KAAK,CAAC,CAAS,EAAE,KAAK,GAAG,CAAC;QACxB,OAAO,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;IACzB,CAAC;;AAXD,kBAAkB;AACF,aAAS,GAAG,KAAK,CAAC;SAFvB,GAAG;AAchB,aAAa,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC;AAEjC;;;;;;GAMG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACrB,CAAC;;AAJD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAOjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACrB,CAAC;;AAJD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAOjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,KAAM,SAAQ,UAAU;IAGnC,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,OAAO,CAAC,GAAG,EAAE,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACnD,CAAC;;AAJD,kBAAkB;AACF,eAAS,GAAG,OAAO,CAAC;SAFzB,KAAK;AAOlB,aAAa,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;AAEnC,gCAAgC;AAChC,MAAa,MAAO,SAAQ,UAAU;IAGpC,KAAK,CAAC,CAAS;QACb,OAAO,CAAC,CAAC;IACX,CAAC;;AAJD,kBAAkB;AACF,gBAAS,GAAG,QAAQ,CAAC;SAF1B,MAAM;AAOnB,aAAa,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;AAEpC;;GAEG;AACH,MAAa,OAAQ,SAAQ,UAAU;IAGrC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IACxB,CAAC;;AAJD,kBAAkB;AACF,iBAAS,GAAG,SAAS,CAAC;SAF3B,OAAO;AAOpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;AAErC;;GAEG;AACH,MAAa,WAAY,SAAQ,UAAU;IAGzC,KAAK,CAAC,CAAS;QACb,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;IAC1B,CAAC;;AAJD,kBAAkB;AACF,qBAAS,GAAG,aAAa,CAAC;SAF/B,WAAW;AAOxB,aAAa,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC;AAEzC;;GAEG;AACH,MAAa,QAAS,SAAQ,UAAU;IAGtC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;IACzB,CAAC;;AAJD,kBAAkB;AACF,kBAAS,GAAG,UAAU,CAAC;SAF5B,QAAQ;AAOrB,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC;AAEtC;;GAEG;AACH,MAAa,QAAS,SAAQ,UAAU;IAGtC,KAAK,CAAC,CAAS;QACb,OAAO,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;IACvB,CAAC;;AAJD,kBAAkB;AACF,kBAAS,GAAG,UAAU,CAAC;SAF5B,QAAQ;AAOrB,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC;AAEtC;;GAEG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACrB,CAAC;;AAJD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAOjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,OAAQ,SAAQ,UAAU;IAGrC;;;;;;;;;;;OAWG;IACH,KAAK,CAAC,CAAS,EAAE,OAAe,CAAC,CAAC,CAAC,CAAC;QAClC,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;IAC9B,CAAC;;AAhBD,kBAAkB;AACF,iBAAS,GAAG,SAAS,CAAC;SAF3B,OAAO;AAmBpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;AAErC;;GAEG;AACH,MAAa,UAAW,SAAQ,UAAU;IAGxC;;;;;;;;;;;;OAYG;IACH,KAAK,CAAC,CAAS,EAAE,OAAe,CAAC,CAAC,CAAC,CAAC;QAClC,OAAO,GAAG,CAAC,UAAU,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;IACjC,CAAC;;AAjBD,kBAAkB;AACF,oBAAS,GAAG,YAAY,CAAC;SAF9B,UAAU;AAoBvB,aAAa,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;AAExC;;GAEG;AACH,MAAa,KAAM,SAAQ,UAAU;IAGnC;;;;;;OAMG;IACH,KAAK,CAAC,CAAS,EAAE,KAAK,GAAG,CAAC;QACxB,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAChE,CAAC;;AAXD,kBAAkB;AACF,eAAS,GAAG,OAAO,CAAC;SAFzB,KAAK;AAclB,aAAa,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;AAEnC;;GAEG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC;;;;;OAKG;IACH,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC;;AAVD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAajB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC,MAAM,UAAU,mBAAmB,CAAC,UAAsB;IACxD,OAAO,UAAU,CAAC,YAAY,EAAE,CAAC;AACnC,CAAC;AAED,MAAM,UAAU,qBAAqB,CACjC,MAAgC,EAChC,gBAA0C,EAAE;IAC9C,OAAO,sBAAsB,CACzB,MAAM,EAAE,aAAa,CAAC,gBAAgB,CAAC,MAAM,EAAE,CAAC,YAAY,EAC5D,aAAa,EAAE,YAAY,CAAC,CAAC;AACnC,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,UACmC;IAC/D,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,MAAM,MAAM,GAA6B,EAAE,CAAC;QAC5C,MAAM,CAAC,WAAW,CAAC,GAAG,QAAQ,CAAC;QAC/B,MAAM,CAAC,QAAQ,CAAC,GAAG,EAAE,CAAC;QACtB,OAAO,qBAAqB,CAAC,MAAM,CAAC,CAAC;KACtC;IACD,IAAI,OAAO,UAAU,KAAK,QAAQ,EAAE;QAClC,MAAM,MAAM,GAA6B,EAAE,CAAC;QAC5C,MAAM,CAAC,WAAW,CAAC,GAAG,UAAU,CAAC;QACjC,MAAM,CAAC,QAAQ,CAAC,GAAG,EAAE,CAAC;QACtB,OAAO,qBAAqB,CAAC,MAAM,CAAC,CAAC;KACtC;SAAM,IAAI,UAAU,YAAY,UAAU,EAAE;QAC3C,OAAO,UAAU,CAAC;KACnB;SAAM;QACL,OAAO,qBAAqB,CAAC,UAAU,CAAC,CAAC;KAC1C;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n// Layer activation functions\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {serialization, Tensor, tidy} from '@tensorflow/tfjs-core';\nimport * as K from './backend/tfjs_backend';\nimport {ActivationIdentifier} from './keras_format/activation_config';\nimport {deserializeKerasObject} from './utils/generic_utils';\n\n/**\n * Base class for Activations.\n *\n * Special note: due to cross-language compatibility reasons, the\n * static readonly className field in this family of classes must be set to\n * the initialLowerCamelCase name of the activation.\n */\nexport abstract class Activation extends serialization.Serializable {\n  abstract apply(tensor: Tensor, axis?: number): Tensor;\n  getConfig(): serialization.ConfigDict {\n    return {};\n  }\n}\n\n/**\n * Exponential linear unit (ELU).\n * Reference: https://arxiv.org/abs/1511.07289\n */\nexport class Elu extends Activation {\n  /** @nocollapse */\n  static readonly className = 'elu';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x: Input.\n   * @param alpha: Scaling factor the negative section.\n   * @return Output of the ELU activation.\n   */\n  apply(x: Tensor, alpha = 1): Tensor {\n    return K.elu(x, alpha);\n  }\n}\nserialization.registerClass(Elu);\n\n/**\n * Scaled Exponential Linear Unit. (Klambauer et al., 2017).\n * Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515\n * Notes:\n *   - To be used together with the initialization \"lecunNormal\".\n *   - To be used together with the dropout variant \"AlphaDropout\".\n */\nexport class Selu extends Activation {\n  /** @nocollapse */\n  static readonly className = 'selu';\n  apply(x: Tensor): Tensor {\n    return tfc.selu(x);\n  }\n}\nserialization.registerClass(Selu);\n\n/**\n *  Rectified linear unit\n */\nexport class Relu extends Activation {\n  /** @nocollapse */\n  static readonly className = 'relu';\n  apply(x: Tensor): Tensor {\n    return tfc.relu(x);\n  }\n}\nserialization.registerClass(Relu);\n\n/**\n * Rectified linear unit activation maxing out at 6.0.\n */\nexport class Relu6 extends Activation {\n  /** @nocollapse */\n  static readonly className = 'relu6';\n  apply(x: Tensor): Tensor {\n    return tidy(() => tfc.minimum(6.0, tfc.relu(x)));\n  }\n}\nserialization.registerClass(Relu6);\n\n//* Linear activation (no-op) */\nexport class Linear extends Activation {\n  /** @nocollapse */\n  static readonly className = 'linear';\n  apply(x: Tensor): Tensor {\n    return x;\n  }\n}\nserialization.registerClass(Linear);\n\n/**\n * Sigmoid activation function.\n */\nexport class Sigmoid extends Activation {\n  /** @nocollapse */\n  static readonly className = 'sigmoid';\n  apply(x: Tensor): Tensor {\n    return tfc.sigmoid(x);\n  }\n}\nserialization.registerClass(Sigmoid);\n\n/**\n * Segment-wise linear approximation of sigmoid.\n */\nexport class HardSigmoid extends Activation {\n  /** @nocollapse */\n  static readonly className = 'hardSigmoid';\n  apply(x: Tensor): Tensor {\n    return K.hardSigmoid(x);\n  }\n}\nserialization.registerClass(HardSigmoid);\n\n/**\n * Softplus activation function.\n */\nexport class Softplus extends Activation {\n  /** @nocollapse */\n  static readonly className = 'softplus';\n  apply(x: Tensor): Tensor {\n    return tfc.softplus(x);\n  }\n}\nserialization.registerClass(Softplus);\n\n/**\n * Softsign activation function.\n */\nexport class Softsign extends Activation {\n  /** @nocollapse */\n  static readonly className = 'softsign';\n  apply(x: Tensor): Tensor {\n    return K.softsign(x);\n  }\n}\nserialization.registerClass(Softsign);\n\n/**\n * Hyperbolic tangent function.\n */\nexport class Tanh extends Activation {\n  /** @nocollapse */\n  static readonly className = 'tanh';\n  apply(x: Tensor): Tensor {\n    return tfc.tanh(x);\n  }\n}\nserialization.registerClass(Tanh);\n\n/**\n * Softmax activation function\n */\nexport class Softmax extends Activation {\n  /** @nocollapse */\n  static readonly className = 'softmax';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @param axis Integer, axis along which the softmax normalization is applied.\n   * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be\n   * an error.\n   *\n   * @returns a Tensor of the same shape as x\n   *\n   * @throws ValueError: In case `dim(x) < 2`.\n   */\n  apply(x: Tensor, axis: number = (-1)): Tensor {\n    return tfc.softmax(x, axis);\n  }\n}\nserialization.registerClass(Softmax);\n\n/**\n * Log softmax activation function\n */\nexport class LogSoftmax extends Activation {\n  /** @nocollapse */\n  static readonly className = 'logSoftmax';\n  /**\n   * Calculate the activation function of log softmax:\n   * log( exp(x_i) / sum(exp(x)) )\n   *\n   * @param x Tensor.\n   * @param axis Integer, axis along which the softmax normalization is applied.\n   * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be\n   * an error.\n   *\n   * @returns a Tensor of the same shape as x\n   *\n   * @throws ValueError: In case `dim(x) < 2`.\n   */\n  apply(x: Tensor, axis: number = (-1)): Tensor {\n    return tfc.logSoftmax(x, axis);\n  }\n}\nserialization.registerClass(LogSoftmax);\n\n/**\n * Swish activation function\n */\nexport class Swish extends Activation {\n  /** @nocollapse */\n  static readonly className = 'swish';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @param alpha Scaling factor for the sigmoid function.\n   * @returns a Tensor of the same shape as x\n   */\n  apply(x: Tensor, alpha = 1): Tensor {\n    return tidy(() => tfc.mul(tfc.sigmoid(tfc.mul(x, alpha)), x));\n  }\n}\nserialization.registerClass(Swish);\n\n/**\n * Mish activation function\n */\nexport class Mish extends Activation {\n  /** @nocollapse */\n  static readonly className = 'mish';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @returns a Tensor of the same shape as x\n   */\n  apply(x: Tensor): Tensor {\n    return tidy(() => tfc.mul(x, tfc.tanh(tfc.softplus(x))));\n  }\n}\nserialization.registerClass(Mish);\n\nexport function serializeActivation(activation: Activation): string {\n  return activation.getClassName();\n}\n\nexport function deserializeActivation(\n    config: serialization.ConfigDict,\n    customObjects: serialization.ConfigDict = {}): Activation {\n  return deserializeKerasObject(\n      config, serialization.SerializationMap.getMap().classNameMap,\n      customObjects, 'activation');\n}\n\nexport function getActivation(identifier: ActivationIdentifier|\n                              serialization.ConfigDict|Activation): Activation {\n  if (identifier == null) {\n    const config: serialization.ConfigDict = {};\n    config['className'] = 'linear';\n    config['config'] = {};\n    return deserializeActivation(config);\n  }\n  if (typeof identifier === 'string') {\n    const config: serialization.ConfigDict = {};\n    config['className'] = identifier;\n    config['config'] = {};\n    return deserializeActivation(config);\n  } else if (identifier instanceof Activation) {\n    return identifier;\n  } else {\n    return deserializeActivation(identifier);\n  }\n}\n"]} | ||
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"activations.js","sourceRoot":"","sources":["../../../../../tfjs-layers/src/activations.ts"],"names":[],"mappings":"AAAA;;;;;;;;GAQG;AAEH,6BAA6B;AAC7B,OAAO,KAAK,GAAG,MAAM,uBAAuB,CAAC;AAC7C,OAAO,EAAC,aAAa,EAAU,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAClE,OAAO,KAAK,CAAC,MAAM,wBAAwB,CAAC;AAE5C,OAAO,EAAC,sBAAsB,EAAC,MAAM,uBAAuB,CAAC;AAE7D;;;;;;GAMG;AACH,MAAM,OAAgB,UAAW,SAAQ,aAAa,CAAC,YAAY;IAEjE,SAAS;QACP,OAAO,EAAE,CAAC;IACZ,CAAC;CACF;AAED;;;GAGG;AACH,MAAa,GAAI,SAAQ,UAAU;IAGjC;;;;;;OAMG;IACH,KAAK,CAAC,CAAS,EAAE,KAAK,GAAG,CAAC;QACxB,OAAO,CAAC,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC;IACzB,CAAC;;AAXD,kBAAkB;AACF,aAAS,GAAG,KAAK,CAAC;SAFvB,GAAG;AAchB,aAAa,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC;AAEjC;;;;;;GAMG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACrB,CAAC;;AAJD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAOjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACrB,CAAC;;AAJD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAOjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,KAAM,SAAQ,UAAU;IAGnC,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,OAAO,CAAC,GAAG,EAAE,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACnD,CAAC;;AAJD,kBAAkB;AACF,eAAS,GAAG,OAAO,CAAC;SAFzB,KAAK;AAOlB,aAAa,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;AAEnC,gCAAgC;AAChC,MAAa,MAAO,SAAQ,UAAU;IAGpC,KAAK,CAAC,CAAS;QACb,OAAO,CAAC,CAAC;IACX,CAAC;;AAJD,kBAAkB;AACF,gBAAS,GAAG,QAAQ,CAAC;SAF1B,MAAM;AAOnB,aAAa,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;AAEpC;;GAEG;AACH,MAAa,OAAQ,SAAQ,UAAU;IAGrC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IACxB,CAAC;;AAJD,kBAAkB;AACF,iBAAS,GAAG,SAAS,CAAC;SAF3B,OAAO;AAOpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;AAErC;;GAEG;AACH,MAAa,WAAY,SAAQ,UAAU;IAGzC,KAAK,CAAC,CAAS;QACb,OAAO,CAAC,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;IAC1B,CAAC;;AAJD,kBAAkB;AACF,qBAAS,GAAG,aAAa,CAAC;SAF/B,WAAW;AAOxB,aAAa,CAAC,aAAa,CAAC,WAAW,CAAC,CAAC;AAEzC;;GAEG;AACH,MAAa,QAAS,SAAQ,UAAU;IAGtC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;IACzB,CAAC;;AAJD,kBAAkB;AACF,kBAAS,GAAG,UAAU,CAAC;SAF5B,QAAQ;AAOrB,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC;AAEtC;;GAEG;AACH,MAAa,QAAS,SAAQ,UAAU;IAGtC,KAAK,CAAC,CAAS;QACb,OAAO,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC;IACvB,CAAC;;AAJD,kBAAkB;AACF,kBAAS,GAAG,UAAU,CAAC;SAF5B,QAAQ;AAOrB,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC;AAEtC;;GAEG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC,KAAK,CAAC,CAAS;QACb,OAAO,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACrB,CAAC;;AAJD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAOjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,OAAQ,SAAQ,UAAU;IAGrC;;;;;;;;;;;OAWG;IACH,KAAK,CAAC,CAAS,EAAE,OAAe,CAAC,CAAC,CAAC,CAAC;QAClC,OAAO,GAAG,CAAC,OAAO,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;IAC9B,CAAC;;AAhBD,kBAAkB;AACF,iBAAS,GAAG,SAAS,CAAC;SAF3B,OAAO;AAmBpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;AAErC;;GAEG;AACH,MAAa,UAAW,SAAQ,UAAU;IAGxC;;;;;;;;;;;;OAYG;IACH,KAAK,CAAC,CAAS,EAAE,OAAe,CAAC,CAAC,CAAC,CAAC;QAClC,OAAO,GAAG,CAAC,UAAU,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;IACjC,CAAC;;AAjBD,kBAAkB;AACF,oBAAS,GAAG,YAAY,CAAC;SAF9B,UAAU;AAoBvB,aAAa,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;AAExC;;GAEG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC;;;;;OAKG;IACH,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,OAAO,GAAG,CAAC,IAAI,CAAC,GAAG,EAAE;gBACnB,MAAM,OAAO,GAAG,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;gBAC7B,sCAAsC;gBACtC,MAAM,GAAG,GAAG,GAAG,CAAC,GAAG,CAAC,GAAG,EAAE,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;gBACnE,6BAA6B;gBAC7B,OAAO,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;YACzB,CAAC,CAAC,CAAC;QACL,CAAC,CAAC,CAAC;IACL,CAAC;;AAlBD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAqBjB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,OAAQ,SAAQ,UAAU;IAGrC;;;;;OAKG;IACH,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,OAAO,GAAG,CAAC,GAAG,CACZ,GAAG,EACH,GAAG,CAAC,GAAG,CACL,CAAC,EACD,GAAG,CAAC,GAAG,CACH,CAAC,EACD,GAAG,CAAC,IAAI,CACN,GAAG,CAAC,GAAG,CACL,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,EAC7B,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,GAAG,CAAC,QAAQ,EAAE,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAC3C,CACJ,CACJ,CACF,CACF,CAAC;QACJ,CAAC,CAAC,CAAC;IACL,CAAC;;AA1BD,kBAAkB;AACF,iBAAS,GAAG,UAAU,CAAC;SAF5B,OAAO;AA6BpB,aAAa,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;AAErC;;GAEG;AACH,MAAa,IAAK,SAAQ,UAAU;IAGlC;;;;;OAKG;IACH,KAAK,CAAC,CAAS;QACb,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC;;AAVD,kBAAkB;AACF,cAAS,GAAG,MAAM,CAAC;SAFxB,IAAI;AAajB,aAAa,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;AAElC;;GAEG;AACH,MAAa,KAAM,SAAQ,UAAU;IAGnC;;;;;;OAMG;IACH,KAAK,CAAC,CAAS,EAAE,KAAK,GAAG,CAAC;QACxB,OAAO,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAChE,CAAC;;AAXD,kBAAkB;AACF,eAAS,GAAG,OAAO,CAAC;SAFzB,KAAK;AAclB,aAAa,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;AAEnC,MAAM,UAAU,mBAAmB,CAAC,UAAsB;IACxD,OAAO,UAAU,CAAC,YAAY,EAAE,CAAC;AACnC,CAAC;AAED,MAAM,UAAU,qBAAqB,CACjC,MAAgC,EAChC,gBAA0C,EAAE;IAC9C,OAAO,sBAAsB,CACzB,MAAM,EAAE,aAAa,CAAC,gBAAgB,CAAC,MAAM,EAAE,CAAC,YAAY,EAC5D,aAAa,EAAE,YAAY,CAAC,CAAC;AACnC,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,UACmC;IAC/D,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,MAAM,MAAM,GAA6B,EAAE,CAAC;QAC5C,MAAM,CAAC,WAAW,CAAC,GAAG,QAAQ,CAAC;QAC/B,MAAM,CAAC,QAAQ,CAAC,GAAG,EAAE,CAAC;QACtB,OAAO,qBAAqB,CAAC,MAAM,CAAC,CAAC;KACtC;IACD,IAAI,OAAO,UAAU,KAAK,QAAQ,EAAE;QAClC,MAAM,MAAM,GAA6B,EAAE,CAAC;QAC5C,MAAM,CAAC,WAAW,CAAC,GAAG,UAAU,CAAC;QACjC,MAAM,CAAC,QAAQ,CAAC,GAAG,EAAE,CAAC;QACtB,OAAO,qBAAqB,CAAC,MAAM,CAAC,CAAC;KACtC;SAAM,IAAI,UAAU,YAAY,UAAU,EAAE;QAC3C,OAAO,UAAU,CAAC;KACnB;SAAM;QACL,OAAO,qBAAqB,CAAC,UAAU,CAAC,CAAC;KAC1C;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC\n *\n * Use of this source code is governed by an MIT-style\n * license that can be found in the LICENSE file or at\n * https://opensource.org/licenses/MIT.\n * =============================================================================\n */\n\n// Layer activation functions\nimport * as tfc from '@tensorflow/tfjs-core';\nimport {serialization, Tensor, tidy} from '@tensorflow/tfjs-core';\nimport * as K from './backend/tfjs_backend';\nimport {ActivationIdentifier} from './keras_format/activation_config';\nimport {deserializeKerasObject} from './utils/generic_utils';\n\n/**\n * Base class for Activations.\n *\n * Special note: due to cross-language compatibility reasons, the\n * static readonly className field in this family of classes must be set to\n * the initialLowerCamelCase name of the activation.\n */\nexport abstract class Activation extends serialization.Serializable {\n  abstract apply(tensor: Tensor, axis?: number): Tensor;\n  getConfig(): serialization.ConfigDict {\n    return {};\n  }\n}\n\n/**\n * Exponential linear unit (ELU).\n * Reference: https://arxiv.org/abs/1511.07289\n */\nexport class Elu extends Activation {\n  /** @nocollapse */\n  static readonly className = 'elu';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x: Input.\n   * @param alpha: Scaling factor the negative section.\n   * @return Output of the ELU activation.\n   */\n  apply(x: Tensor, alpha = 1): Tensor {\n    return K.elu(x, alpha);\n  }\n}\nserialization.registerClass(Elu);\n\n/**\n * Scaled Exponential Linear Unit. (Klambauer et al., 2017).\n * Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515\n * Notes:\n *   - To be used together with the initialization \"lecunNormal\".\n *   - To be used together with the dropout variant \"AlphaDropout\".\n */\nexport class Selu extends Activation {\n  /** @nocollapse */\n  static readonly className = 'selu';\n  apply(x: Tensor): Tensor {\n    return tfc.selu(x);\n  }\n}\nserialization.registerClass(Selu);\n\n/**\n *  Rectified linear unit\n */\nexport class Relu extends Activation {\n  /** @nocollapse */\n  static readonly className = 'relu';\n  apply(x: Tensor): Tensor {\n    return tfc.relu(x);\n  }\n}\nserialization.registerClass(Relu);\n\n/**\n * Rectified linear unit activation maxing out at 6.0.\n */\nexport class Relu6 extends Activation {\n  /** @nocollapse */\n  static readonly className = 'relu6';\n  apply(x: Tensor): Tensor {\n    return tidy(() => tfc.minimum(6.0, tfc.relu(x)));\n  }\n}\nserialization.registerClass(Relu6);\n\n//* Linear activation (no-op) */\nexport class Linear extends Activation {\n  /** @nocollapse */\n  static readonly className = 'linear';\n  apply(x: Tensor): Tensor {\n    return x;\n  }\n}\nserialization.registerClass(Linear);\n\n/**\n * Sigmoid activation function.\n */\nexport class Sigmoid extends Activation {\n  /** @nocollapse */\n  static readonly className = 'sigmoid';\n  apply(x: Tensor): Tensor {\n    return tfc.sigmoid(x);\n  }\n}\nserialization.registerClass(Sigmoid);\n\n/**\n * Segment-wise linear approximation of sigmoid.\n */\nexport class HardSigmoid extends Activation {\n  /** @nocollapse */\n  static readonly className = 'hardSigmoid';\n  apply(x: Tensor): Tensor {\n    return K.hardSigmoid(x);\n  }\n}\nserialization.registerClass(HardSigmoid);\n\n/**\n * Softplus activation function.\n */\nexport class Softplus extends Activation {\n  /** @nocollapse */\n  static readonly className = 'softplus';\n  apply(x: Tensor): Tensor {\n    return tfc.softplus(x);\n  }\n}\nserialization.registerClass(Softplus);\n\n/**\n * Softsign activation function.\n */\nexport class Softsign extends Activation {\n  /** @nocollapse */\n  static readonly className = 'softsign';\n  apply(x: Tensor): Tensor {\n    return K.softsign(x);\n  }\n}\nserialization.registerClass(Softsign);\n\n/**\n * Hyperbolic tangent function.\n */\nexport class Tanh extends Activation {\n  /** @nocollapse */\n  static readonly className = 'tanh';\n  apply(x: Tensor): Tensor {\n    return tfc.tanh(x);\n  }\n}\nserialization.registerClass(Tanh);\n\n/**\n * Softmax activation function\n */\nexport class Softmax extends Activation {\n  /** @nocollapse */\n  static readonly className = 'softmax';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @param axis Integer, axis along which the softmax normalization is applied.\n   * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be\n   * an error.\n   *\n   * @returns a Tensor of the same shape as x\n   *\n   * @throws ValueError: In case `dim(x) < 2`.\n   */\n  apply(x: Tensor, axis: number = (-1)): Tensor {\n    return tfc.softmax(x, axis);\n  }\n}\nserialization.registerClass(Softmax);\n\n/**\n * Log softmax activation function\n */\nexport class LogSoftmax extends Activation {\n  /** @nocollapse */\n  static readonly className = 'logSoftmax';\n  /**\n   * Calculate the activation function of log softmax:\n   * log( exp(x_i) / sum(exp(x)) )\n   *\n   * @param x Tensor.\n   * @param axis Integer, axis along which the softmax normalization is applied.\n   * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be\n   * an error.\n   *\n   * @returns a Tensor of the same shape as x\n   *\n   * @throws ValueError: In case `dim(x) < 2`.\n   */\n  apply(x: Tensor, axis: number = (-1)): Tensor {\n    return tfc.logSoftmax(x, axis);\n  }\n}\nserialization.registerClass(LogSoftmax);\n\n/**\n * Gelu activation function\n */\nexport class Gelu extends Activation {\n  /** @nocollapse */\n  static readonly className = 'gelu';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @returns a Tensor of the same shape as x\n   */\n  apply(x: Tensor): Tensor {\n    return tidy(() => {\n      return tfc.tidy(() => {\n        const sqrtTwo = Math.sqrt(2);\n        // Compute Φ(x) using the erf function\n        const cdf = tfc.mul(0.5, tfc.add(1, tfc.erf(tfc.div(x, sqrtTwo))));\n        // Compute GELU(x) = x * Φ(x)\n        return tfc.mul(x, cdf);\n      });\n    });\n  }\n}\nserialization.registerClass(Gelu);\n\n/**\n * GeluNew activation function\n */\nexport class GeluNew extends Activation {\n  /** @nocollapse */\n  static readonly className = 'gelu_new';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @returns a Tensor of the same shape as x\n   */\n  apply(x: Tensor): Tensor {\n    return tidy(() => {\n      return tfc.mul(\n        0.5,\n        tfc.mul(\n          x,\n          tfc.add(\n              1,\n              tfc.tanh(\n                tfc.mul(\n                  tfc.sqrt(tfc.div(2, Math.PI)),\n                  tfc.add(x, tfc.mul(0.044715, tfc.pow(x, 3)))\n                  )\n              )\n          )\n        )\n      );\n    });\n  }\n}\nserialization.registerClass(GeluNew);\n\n/**\n * Mish activation function\n */\nexport class Mish extends Activation {\n  /** @nocollapse */\n  static readonly className = 'mish';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @returns a Tensor of the same shape as x\n   */\n  apply(x: Tensor): Tensor {\n    return tidy(() => tfc.mul(x, tfc.tanh(tfc.softplus(x))));\n  }\n}\nserialization.registerClass(Mish);\n\n/**\n * Swish activation function\n */\nexport class Swish extends Activation {\n  /** @nocollapse */\n  static readonly className = 'swish';\n  /**\n   * Calculate the activation function.\n   *\n   * @param x Tensor.\n   * @param alpha Scaling factor for the sigmoid function.\n   * @returns a Tensor of the same shape as x\n   */\n  apply(x: Tensor, alpha = 1): Tensor {\n    return tidy(() => tfc.mul(tfc.sigmoid(tfc.mul(x, alpha)), x));\n  }\n}\nserialization.registerClass(Swish);\n\nexport function serializeActivation(activation: Activation): string {\n  return activation.getClassName();\n}\n\nexport function deserializeActivation(\n    config: serialization.ConfigDict,\n    customObjects: serialization.ConfigDict = {}): Activation {\n  return deserializeKerasObject(\n      config, serialization.SerializationMap.getMap().classNameMap,\n      customObjects, 'activation');\n}\n\nexport function getActivation(identifier: ActivationIdentifier|\n                              serialization.ConfigDict|Activation): Activation {\n  if (identifier == null) {\n    const config: serialization.ConfigDict = {};\n    config['className'] = 'linear';\n    config['config'] = {};\n    return deserializeActivation(config);\n  }\n  if (typeof identifier === 'string') {\n    const config: serialization.ConfigDict = {};\n    config['className'] = identifier;\n    config['config'] = {};\n    return deserializeActivation(config);\n  } else if (identifier instanceof Activation) {\n    return identifier;\n  } else {\n    return deserializeActivation(identifier);\n  }\n}\n"]} |
@@ -14,3 +14,3 @@ /** | ||
*/ | ||
export declare const activationOptions: ("linear" | "relu" | "elu" | "relu6" | "sigmoid" | "hard_sigmoid" | "selu" | "softmax" | "softplus" | "softsign" | "tanh" | "swish" | "mish")[]; | ||
export declare const activationOptions: ("linear" | "relu" | "elu" | "relu6" | "sigmoid" | "hard_sigmoid" | "selu" | "softmax" | "softplus" | "softsign" | "tanh" | "swish" | "mish" | "gelu" | "gelu_new")[]; | ||
/** | ||
@@ -21,2 +21,2 @@ * A type representing the strings that are valid loss names. | ||
/** @docinline */ | ||
export type ActivationIdentifier = 'elu' | 'hardSigmoid' | 'linear' | 'relu' | 'relu6' | 'selu' | 'sigmoid' | 'softmax' | 'softplus' | 'softsign' | 'tanh' | 'swish' | 'mish'; | ||
export type ActivationIdentifier = 'elu' | 'hardSigmoid' | 'linear' | 'relu' | 'relu6' | 'selu' | 'sigmoid' | 'softmax' | 'softplus' | 'softsign' | 'tanh' | 'swish' | 'mish' | 'gelu' | 'gelu_new'; |
@@ -16,4 +16,4 @@ /** | ||
'elu', 'hard_sigmoid', 'linear', 'relu', 'relu6', 'selu', 'sigmoid', | ||
'softmax', 'softplus', 'softsign', 'tanh', 'swish', 'mish' | ||
'softmax', 'softplus', 'softsign', 'tanh', 'swish', 'mish', 'gelu', 'gelu_new' | ||
]); | ||
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYWN0aXZhdGlvbl9jb25maWcuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMva2VyYXNfZm9ybWF0L2FjdGl2YXRpb25fY29uZmlnLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7OztHQVFHO0FBRUgsT0FBTyxFQUFDLGtCQUFrQixFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRTNDOztHQUVHO0FBQ0gsTUFBTSxDQUFDLE1BQU0saUJBQWlCLEdBQUcsa0JBQWtCLENBQUM7SUFDbEQsS0FBSyxFQUFFLGNBQWMsRUFBRSxRQUFRLEVBQUUsTUFBTSxFQUFFLE9BQU8sRUFBRSxNQUFNLEVBQUUsU0FBUztJQUNuRSxTQUFTLEVBQUUsVUFBVSxFQUFFLFVBQVUsRUFBRSxNQUFNLEVBQUUsT0FBTyxFQUFFLE1BQU07Q0FDM0QsQ0FBQyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTggR29vZ2xlIExMQ1xuICpcbiAqIFVzZSBvZiB0aGlzIHNvdXJjZSBjb2RlIGlzIGdvdmVybmVkIGJ5IGFuIE1JVC1zdHlsZVxuICogbGljZW5zZSB0aGF0IGNhbiBiZSBmb3VuZCBpbiB0aGUgTElDRU5TRSBmaWxlIG9yIGF0XG4gKiBodHRwczovL29wZW5zb3VyY2Uub3JnL2xpY2Vuc2VzL01JVC5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtzdHJpbmdMaXRlcmFsQXJyYXl9IGZyb20gJy4vdXRpbHMnO1xuXG4vKipcbiAqIExpc3Qgb2YgYWxsIGtub3duIGFjdGl2YXRpb24gbmFtZXMuXG4gKi9cbmV4cG9ydCBjb25zdCBhY3RpdmF0aW9uT3B0aW9ucyA9IHN0cmluZ0xpdGVyYWxBcnJheShbXG4gICdlbHUnLCAnaGFyZF9zaWdtb2lkJywgJ2xpbmVhcicsICdyZWx1JywgJ3JlbHU2JywgJ3NlbHUnLCAnc2lnbW9pZCcsXG4gICdzb2Z0bWF4JywgJ3NvZnRwbHVzJywgJ3NvZnRzaWduJywgJ3RhbmgnLCAnc3dpc2gnLCAnbWlzaCdcbl0pO1xuXG4vKipcbiAqIEEgdHlwZSByZXByZXNlbnRpbmcgdGhlIHN0cmluZ3MgdGhhdCBhcmUgdmFsaWQgbG9zcyBuYW1lcy5cbiAqL1xuZXhwb3J0IHR5cGUgQWN0aXZhdGlvblNlcmlhbGl6YXRpb24gPSB0eXBlb2YgYWN0aXZhdGlvbk9wdGlvbnNbbnVtYmVyXTtcblxuLy8gU2FkIHRoYXQgd2UgaGF2ZSB0byBkbyBhbGwgdGhpcyBqdXN0IGZvciBoYXJkX3NpZ21vaWQgdnMuIGhhcmRTaWdtb2lkLlxuLy8gVE9ETyhzb2VyZ2VsKTogTW92ZSB0aGUgQ2FtZWxDYXNlIHZlcnNpb25zIGJhY2sgb3V0IG9mIGtlcmFzX2Zvcm1hdFxuLy8gZS5nLiB0byBzcmMvY29tbW9uLnRzLiAgTWF5YmUgZXZlbiBkdXBsaWNhdGUgKmFsbCogb2YgdGhlc2UgdG8gYmUgcGVkYW50aWM/XG4vKiogQGRvY2lubGluZSAqL1xuZXhwb3J0IHR5cGUgQWN0aXZhdGlvbklkZW50aWZpZXIgPSAnZWx1J3wnaGFyZFNpZ21vaWQnfCdsaW5lYXInfCdyZWx1J3wncmVsdTYnfFxuICAgICdzZWx1J3wnc2lnbW9pZCd8J3NvZnRtYXgnfCdzb2Z0cGx1cyd8J3NvZnRzaWduJ3wndGFuaCd8J3N3aXNoJ3wnbWlzaCc7XG4iXX0= | ||
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYWN0aXZhdGlvbl9jb25maWcuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWxheWVycy9zcmMva2VyYXNfZm9ybWF0L2FjdGl2YXRpb25fY29uZmlnLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7OztHQVFHO0FBRUgsT0FBTyxFQUFDLGtCQUFrQixFQUFDLE1BQU0sU0FBUyxDQUFDO0FBRTNDOztHQUVHO0FBQ0gsTUFBTSxDQUFDLE1BQU0saUJBQWlCLEdBQUcsa0JBQWtCLENBQUM7SUFDbEQsS0FBSyxFQUFFLGNBQWMsRUFBRSxRQUFRLEVBQUUsTUFBTSxFQUFFLE9BQU8sRUFBRSxNQUFNLEVBQUUsU0FBUztJQUNuRSxTQUFTLEVBQUUsVUFBVSxFQUFFLFVBQVUsRUFBRSxNQUFNLEVBQUUsT0FBTyxFQUFFLE1BQU0sRUFBRSxNQUFNLEVBQUUsVUFBVTtDQUMvRSxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDXG4gKlxuICogVXNlIG9mIHRoaXMgc291cmNlIGNvZGUgaXMgZ292ZXJuZWQgYnkgYW4gTUlULXN0eWxlXG4gKiBsaWNlbnNlIHRoYXQgY2FuIGJlIGZvdW5kIGluIHRoZSBMSUNFTlNFIGZpbGUgb3IgYXRcbiAqIGh0dHBzOi8vb3BlbnNvdXJjZS5vcmcvbGljZW5zZXMvTUlULlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge3N0cmluZ0xpdGVyYWxBcnJheX0gZnJvbSAnLi91dGlscyc7XG5cbi8qKlxuICogTGlzdCBvZiBhbGwga25vd24gYWN0aXZhdGlvbiBuYW1lcy5cbiAqL1xuZXhwb3J0IGNvbnN0IGFjdGl2YXRpb25PcHRpb25zID0gc3RyaW5nTGl0ZXJhbEFycmF5KFtcbiAgJ2VsdScsICdoYXJkX3NpZ21vaWQnLCAnbGluZWFyJywgJ3JlbHUnLCAncmVsdTYnLCAnc2VsdScsICdzaWdtb2lkJyxcbiAgJ3NvZnRtYXgnLCAnc29mdHBsdXMnLCAnc29mdHNpZ24nLCAndGFuaCcsICdzd2lzaCcsICdtaXNoJywgJ2dlbHUnLCAnZ2VsdV9uZXcnXG5dKTtcblxuLyoqXG4gKiBBIHR5cGUgcmVwcmVzZW50aW5nIHRoZSBzdHJpbmdzIHRoYXQgYXJlIHZhbGlkIGxvc3MgbmFtZXMuXG4gKi9cbmV4cG9ydCB0eXBlIEFjdGl2YXRpb25TZXJpYWxpemF0aW9uID0gdHlwZW9mIGFjdGl2YXRpb25PcHRpb25zW251bWJlcl07XG5cbi8vIFNhZCB0aGF0IHdlIGhhdmUgdG8gZG8gYWxsIHRoaXMganVzdCBmb3IgaGFyZF9zaWdtb2lkIHZzLiBoYXJkU2lnbW9pZC5cbi8vIFRPRE8oc29lcmdlbCk6IE1vdmUgdGhlIENhbWVsQ2FzZSB2ZXJzaW9ucyBiYWNrIG91dCBvZiBrZXJhc19mb3JtYXRcbi8vIGUuZy4gdG8gc3JjL2NvbW1vbi50cy4gIE1heWJlIGV2ZW4gZHVwbGljYXRlICphbGwqIG9mIHRoZXNlIHRvIGJlIHBlZGFudGljP1xuLyoqIEBkb2NpbmxpbmUgKi9cbmV4cG9ydCB0eXBlIEFjdGl2YXRpb25JZGVudGlmaWVyID0gJ2VsdSd8J2hhcmRTaWdtb2lkJ3wnbGluZWFyJ3wncmVsdSd8J3JlbHU2J3xcbiAgICAnc2VsdSd8J3NpZ21vaWQnfCdzb2Z0bWF4J3wnc29mdHBsdXMnfCdzb2Z0c2lnbid8J3RhbmgnfCdzd2lzaCd8J21pc2gnfCdnZWx1J3wnZ2VsdV9uZXcnO1xuIl19 |
@@ -108,4 +108,3 @@ /** | ||
layerNormEpsilon: 1e-05, | ||
// TODO(pforderique): Implement gelu. | ||
activation: getActivation('relu'), | ||
activation: getActivation('gelu'), | ||
kernelInitializer: gpt2KernelInitializer(0.02), | ||
@@ -158,2 +157,2 @@ normalizeFirst: true, | ||
serialization.registerClass(GPT2Backbone); | ||
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"gpt2_backbone.js","sourceRoot":"","sources":["../../../../../../../../../tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,6DAA6D;AAC7D,OAAO,EAAE,aAAa,EAAE,MAAM,uBAAuB,CAAC;AAEtD,OAAO,EAAE,YAAY,EAAE,MAAM,0BAA0B,CAAC;AACxD,OAAO,EAAE,KAAK,EAAE,MAAM,qBAAqB,CAAC;AAC5C,OAAO,EAAE,SAAS,EAAE,MAAM,qBAAqB,CAAC;AAEhD,OAAO,EAAE,iBAAiB,EAAE,MAAM,mCAAmC,CAAC;AACtE,OAAO,EAAE,GAAG,EAAE,MAAM,4BAA4B,CAAC;AACjD,OAAO,EAAE,OAAO,EAAE,MAAM,eAAe,CAAC;AACxC,OAAO,EAAE,kBAAkB,EAAE,MAAM,oCAAoC,CAAC;AACxE,OAAO,EAAE,aAAa,EAAE,MAAM,yBAAyB,CAAC;AACxD,OAAO,EAAE,kBAAkB,EAAE,MAAM,wBAAwB,CAAC;AAC5D,OAAO,EAAE,QAAQ,EAAE,MAAM,aAAa,CAAC;AAEvC,SAAS,qBAAqB,CAAC,MAAM,GAAG,IAAI;IAC1C,OAAO,IAAI,YAAY,CAAC,EAAC,MAAM,EAAC,CAAC,CAAC;AACpC,CAAC;AA6CD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAwCG;AACH,MAAa,YAAa,SAAQ,QAAQ;IAYxC,YAAY,IAAsB;;QAChC,IAAI,CAAC,OAAO,GAAG,MAAA,IAAI,CAAC,OAAO,mCAAI,GAAG,CAAC;QACnC,IAAI,CAAC,iBAAiB,GAAG,MAAA,IAAI,CAAC,iBAAiB,mCAAI,IAAI,CAAC;QAExD,SAAS;QACT,MAAM,QAAQ,GAAG,KAAK,CAAC,EAAC,KAAK,EAAE,CAAC,IAAI,CAAC,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,EAAE,WAAW,EAAC,CAAC,CAAC;QAC3E,MAAM,WAAW,GACf,KAAK,CAAC,EAAC,KAAK,EAAE,CAAC,IAAI,CAAC,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,EAAE,cAAc,EAAC,CAAC,CAAC;QAE/D,2BAA2B;QAC3B,MAAM,cAAc,GAAG,IAAI,SAAS,CAAC;YACnC,QAAQ,EAAE,IAAI,CAAC,cAAc;YAC7B,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,qBAAqB,EAAE,qBAAqB,CAAC,IAAI,CAAC;YAClD,IAAI,EAAE,iBAAiB;SACxB,CAAC,CAAC,KAAK,CAAC,QAAQ,CAAmB,CAAC;QAErC,MAAM,iBAAiB,GAAG,IAAI,iBAAiB,CAAC;YAC9C,WAAW,EAAE,qBAAqB,CAAC,IAAI,CAAC;YACxC,cAAc,EAAE,IAAI,CAAC,iBAAiB;YACtC,IAAI,EAAE,oBAAoB;SAC3B,CAAC,CAAC,KAAK,CAAC,cAAc,CAAmB,CAAC;QAE3C,uCAAuC;QACvC,IAAI,CAAC,GAAG,GAAG,CAAC,EAAC,IAAI,EAAE,gBAAgB,EAAC,CAAC;aAClC,KAAK,CAAC,CAAC,cAAc,EAAE,iBAAiB,CAAC,CAAmB,CAAC;QAChE,CAAC,GAAG,IAAI,OAAO,CAAC,EAAC,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,IAAI,EAAE,oBAAoB,EAAC,CAAC;aAC9D,KAAK,CAAC,CAAC,CAAmB,CAAC;QAE9B,+CAA+C;QAC/C,KAAI,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,SAAS,EAAE,CAAC,EAAE,EAAE;YACtC,CAAC,GAAG,IAAI,kBAAkB,CAAC;gBACzB,eAAe,EAAE,IAAI,CAAC,eAAe;gBACrC,QAAQ,EAAE,IAAI,CAAC,QAAQ;gBACvB,OAAO,EAAE,IAAI,CAAC,OAAO;gBACrB,gBAAgB,EAAE,KAAK;gBACvB,qCAAqC;gBACrC,UAAU,EAAE,aAAa,CAAC,MAAM,CAAC;gBACjC,iBAAiB,EAAE,qBAAqB,CAAC,IAAI,CAAC;gBAC9C,cAAc,EAAE,IAAI;gBACpB,IAAI,EAAE,qBAAqB,CAAC,EAAE;aAC/B,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,EAAC,kBAAkB,EAAE,WAAW,EAAC,CAAmB,CAAC;SAClE;QAED,MAAM,cAAc,GAAG,IAAI,kBAAkB,CAAC;YAC5C,IAAI,EAAE,YAAY;YAClB,IAAI,EAAE,CAAC,CAAC;YACR,OAAO,EAAE,KAAK;YACd,KAAK,EAAE,SAAS;SACjB,CAAC,CAAC,KAAK,CAAC,CAAC,CAAmB,CAAC;QAE9B,sDAAsD;QACtD,KAAK,CAAC;YACJ,MAAM,EAAE,CAAC,QAAQ,EAAE,WAAW,CAAC;YAC/B,OAAO,EAAE,cAAc;YACvB,IAAI,EAAE,eAAe;SACtB,CAAC,CAAC;QACH,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC,eAAe,CAAC;QAC5C,IAAI,CAAC,OAAO,GAAG,MAAA,IAAI,CAAC,OAAO,mCAAI,GAAG,CAAC;QACnC,IAAI,CAAC,iBAAiB,GAAG,MAAA,IAAI,CAAC,iBAAiB,mCAAI,IAAI,CAAC;IAC1D,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,eAAe,EAAE,IAAI,CAAC,eAAe;YACrC,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;SAC1C,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAED,IAAa,cAAc;QACzB,OAAO,IAAI,CAAC,QAAQ,CAAC,iBAAiB,CAAc,CAAC;IACvD,CAAC;;AA9FD,kBAAkB;AACF,sBAAS,GAAG,cAAc,CAAC;SAFhC,YAAY;AAiGzB,aAAa,CAAC,aAAa,CAAC,YAAY,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Base class for Backbone models.\n */\n\n/* Original source: keras_nlp/models/gpt2/gpt2_backbone.py */\nimport { serialization } from '@tensorflow/tfjs-core';\n\nimport { RandomNormal } from '../../../../initializers';\nimport { input } from '../../../../exports';\nimport { Embedding } from '../../../embeddings';\nimport { SymbolicTensor } from '../../../../engine/topology';\nimport { PositionEmbedding } from '../../modeling/position_embedding';\nimport { add } from '../../../../exports_layers';\nimport { Dropout } from '../../../core';\nimport { TransformerDecoder } from '../../modeling/transformer_decoder';\nimport { getActivation } from '../../../../activations';\nimport { LayerNormalization } from '../../../normalization';\nimport { Backbone } from '../backbone';\n\nfunction gpt2KernelInitializer(stddev = 0.02) {\n  return new RandomNormal({stddev});\n}\n\nexport interface GPT2BackboneArgs  {\n  /**\n   * Integer. The size of the token vocabulary.\n   */\n  vocabularySize: number;\n\n  /**\n   * Integer. The number of transformer layers.\n   */\n  numLayers: number;\n\n  /**\n   * Integer. The number of attention heads for each transformer.\n   * The hidden size must be divisible by the number of attention heads.\n   */\n  numHeads: number;\n\n  /**\n   * Integer. The size of the transformer encoding and pooler layers.\n   */\n  hiddenDim: number;\n\n  /**\n   * Integer. The output dimension of the first Dense layer in a two-layer\n   * feedforward network for each transformer.\n   */\n  intermediateDim: number;\n\n  /**\n   * Float. Dropout probability for the Transformer encoder.\n   * Defaults to 0.2.\n   */\n  dropout?: number;\n\n  /**\n   * Integer. The maximum sequence length that this encoder can consume.\n   * If `null`, `maxSequenceLength` uses the value from sequence length.\n   * This determines the variable shape for positional embeddings.\n   * Defaults to 1024.\n   */\n  maxSequenceLength?: number;\n}\n\n/**\n * GPT-2 core network with hyperparameters.\n *\n * This network implements a Transformer-based decoder network,\n * Generative Pretrained Transformer-2 (GPT-2), as described in\n * [\"Language Models are Unsupervised Multitask Learners\"](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf).\n * It includes the embedding lookups and transformer layers.\n *\n * The default constructor gives a fully customizable, randomly initialized\n * GPT-2 model with any number of layers, heads, and embedding\n * dimensions. To load preset architectures and weights, use the `fromPreset`\n * constructor.\n *\n * Disclaimer: Pre-trained models are provided on an \"as is\" basis, without\n * warranties or conditions of any kind. The underlying model is provided by a\n * third party and subject to a separate license, available\n * [here](https://github.com/openai/gpt-2).\n *\n *\n * Example usage:\n * ```js\n * const tokenIds = tf.ones([1, 12]), dtype=\"int32\");\n * const paddingMask = tf.tensor(\n *  [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], 'int32');\n *\n * # Pretrained GPT-2 decoder.\n * model = GPT2Backbone.fromPreset(\"gpt2_base_en\");\n * model.apply(inputData, {paddingMask});\n *\n * # Randomly initialized GPT-2 decoder with custom config.\n * model = kerasNlp.models.GPT2Backbone({\n *     vocabularySize: 50257,\n *     numLayers: 12,\n *     numHeads: 12,\n *     hiddenDim: 768,\n *     intermediateDim: 3072,\n *     maxSequenceLength: 1024,\n * });\n * model.apply(inputData, {paddingMask});\n * ```\n */\nexport class GPT2Backbone extends Backbone {\n  /** @nocollapse */\n  static override className = 'GPT2Backbone';\n\n  private vocabularySize: number;\n  private numLayers: number;\n  private numHeads: number;\n  private hiddenDim: number;\n  private intermediateDim: number;\n  private dropout: number;\n  private maxSequenceLength: number;\n\n  constructor(args: GPT2BackboneArgs) {\n    args.dropout = args.dropout ?? 0.1;\n    args.maxSequenceLength = args.maxSequenceLength ?? 1024;\n\n    // Inputs\n    const tokenIds = input({shape: [null], dtype: 'int32', name: 'token_ids'});\n    const paddingMask =\n      input({shape: [null], dtype: 'int32', name: 'padding_mask'});\n\n    // Embed tokens, positions.\n    const tokenEmbedding = new Embedding({\n      inputDim: args.vocabularySize,\n      outputDim: args.hiddenDim,\n      embeddingsInitializer: gpt2KernelInitializer(0.01),\n      name: 'token_embedding',\n    }).apply(tokenIds) as SymbolicTensor;\n\n    const positionEmbedding = new PositionEmbedding({\n      initializer: gpt2KernelInitializer(0.02),\n      sequenceLength: args.maxSequenceLength,\n      name: 'position_embedding',\n    }).apply(tokenEmbedding) as SymbolicTensor;\n\n    // Sum and apply dropout to embeddings.\n    let x = add({name: 'embeddings_add'})\n      .apply([tokenEmbedding, positionEmbedding]) as SymbolicTensor;\n    x = new Dropout({rate: args.dropout, name: 'embeddings_dropout'})\n      .apply(x) as SymbolicTensor;\n\n    // Apply successive transformer decoder blocks.\n    for(let i = 0; i < args.numLayers; i++) {\n      x = new TransformerDecoder({\n        intermediateDim: args.intermediateDim,\n        numHeads: args.numHeads,\n        dropout: args.dropout,\n        layerNormEpsilon: 1e-05,\n        // TODO(pforderique): Implement gelu.\n        activation: getActivation('relu'),\n        kernelInitializer: gpt2KernelInitializer(0.02),\n        normalizeFirst: true,\n        name: `transformer_layer_${i}`,\n      }).apply(x, {decoderPaddingMask: paddingMask}) as SymbolicTensor;\n    }\n\n    const sequenceOutput = new LayerNormalization({\n      name: 'layer_norm',\n      axis: -1,\n      epsilon: 1e-05,\n      dtype: 'float32',\n    }).apply(x) as SymbolicTensor;\n\n    // Instantiate using Functional API Model constructor.\n    super({\n      inputs: [tokenIds, paddingMask],\n      outputs: sequenceOutput,\n      name: 'gpt2_backbone'\n    });\n    this.vocabularySize = args.vocabularySize;\n    this.numLayers = args.numLayers;\n    this.numHeads = args.numHeads;\n    this.hiddenDim = args.hiddenDim;\n    this.intermediateDim = args.intermediateDim;\n    this.dropout = args.dropout ?? 0.1;\n    this.maxSequenceLength = args.maxSequenceLength ?? 1024;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      vocabularySize: this.vocabularySize,\n      numLayers: this.numLayers,\n      numHeads: this.numHeads,\n      hiddenDim: this.hiddenDim,\n      intermediateDim: this.intermediateDim,\n      dropout: this.dropout,\n      maxSequenceLength: this.maxSequenceLength,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override get tokenEmbedding(): Embedding {\n    return this.getLayer('token_embedding') as Embedding;\n  }\n}\nserialization.registerClass(GPT2Backbone);\n"]} | ||
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"gpt2_backbone.js","sourceRoot":"","sources":["../../../../../../../../../tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,6DAA6D;AAC7D,OAAO,EAAE,aAAa,EAAE,MAAM,uBAAuB,CAAC;AAEtD,OAAO,EAAE,YAAY,EAAE,MAAM,0BAA0B,CAAC;AACxD,OAAO,EAAE,KAAK,EAAE,MAAM,qBAAqB,CAAC;AAC5C,OAAO,EAAE,SAAS,EAAE,MAAM,qBAAqB,CAAC;AAEhD,OAAO,EAAE,iBAAiB,EAAE,MAAM,mCAAmC,CAAC;AACtE,OAAO,EAAE,GAAG,EAAE,MAAM,4BAA4B,CAAC;AACjD,OAAO,EAAE,OAAO,EAAE,MAAM,eAAe,CAAC;AACxC,OAAO,EAAE,kBAAkB,EAAE,MAAM,oCAAoC,CAAC;AACxE,OAAO,EAAE,aAAa,EAAE,MAAM,yBAAyB,CAAC;AACxD,OAAO,EAAE,kBAAkB,EAAE,MAAM,wBAAwB,CAAC;AAC5D,OAAO,EAAE,QAAQ,EAAE,MAAM,aAAa,CAAC;AAEvC,SAAS,qBAAqB,CAAC,MAAM,GAAG,IAAI;IAC1C,OAAO,IAAI,YAAY,CAAC,EAAC,MAAM,EAAC,CAAC,CAAC;AACpC,CAAC;AA6CD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAwCG;AACH,MAAa,YAAa,SAAQ,QAAQ;IAYxC,YAAY,IAAsB;;QAChC,IAAI,CAAC,OAAO,GAAG,MAAA,IAAI,CAAC,OAAO,mCAAI,GAAG,CAAC;QACnC,IAAI,CAAC,iBAAiB,GAAG,MAAA,IAAI,CAAC,iBAAiB,mCAAI,IAAI,CAAC;QAExD,SAAS;QACT,MAAM,QAAQ,GAAG,KAAK,CAAC,EAAC,KAAK,EAAE,CAAC,IAAI,CAAC,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,EAAE,WAAW,EAAC,CAAC,CAAC;QAC3E,MAAM,WAAW,GACf,KAAK,CAAC,EAAC,KAAK,EAAE,CAAC,IAAI,CAAC,EAAE,KAAK,EAAE,OAAO,EAAE,IAAI,EAAE,cAAc,EAAC,CAAC,CAAC;QAE/D,2BAA2B;QAC3B,MAAM,cAAc,GAAG,IAAI,SAAS,CAAC;YACnC,QAAQ,EAAE,IAAI,CAAC,cAAc;YAC7B,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,qBAAqB,EAAE,qBAAqB,CAAC,IAAI,CAAC;YAClD,IAAI,EAAE,iBAAiB;SACxB,CAAC,CAAC,KAAK,CAAC,QAAQ,CAAmB,CAAC;QAErC,MAAM,iBAAiB,GAAG,IAAI,iBAAiB,CAAC;YAC9C,WAAW,EAAE,qBAAqB,CAAC,IAAI,CAAC;YACxC,cAAc,EAAE,IAAI,CAAC,iBAAiB;YACtC,IAAI,EAAE,oBAAoB;SAC3B,CAAC,CAAC,KAAK,CAAC,cAAc,CAAmB,CAAC;QAE3C,uCAAuC;QACvC,IAAI,CAAC,GAAG,GAAG,CAAC,EAAC,IAAI,EAAE,gBAAgB,EAAC,CAAC;aAClC,KAAK,CAAC,CAAC,cAAc,EAAE,iBAAiB,CAAC,CAAmB,CAAC;QAChE,CAAC,GAAG,IAAI,OAAO,CAAC,EAAC,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,IAAI,EAAE,oBAAoB,EAAC,CAAC;aAC9D,KAAK,CAAC,CAAC,CAAmB,CAAC;QAE9B,+CAA+C;QAC/C,KAAI,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,SAAS,EAAE,CAAC,EAAE,EAAE;YACtC,CAAC,GAAG,IAAI,kBAAkB,CAAC;gBACzB,eAAe,EAAE,IAAI,CAAC,eAAe;gBACrC,QAAQ,EAAE,IAAI,CAAC,QAAQ;gBACvB,OAAO,EAAE,IAAI,CAAC,OAAO;gBACrB,gBAAgB,EAAE,KAAK;gBACvB,UAAU,EAAE,aAAa,CAAC,MAAM,CAAC;gBACjC,iBAAiB,EAAE,qBAAqB,CAAC,IAAI,CAAC;gBAC9C,cAAc,EAAE,IAAI;gBACpB,IAAI,EAAE,qBAAqB,CAAC,EAAE;aAC/B,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,EAAC,kBAAkB,EAAE,WAAW,EAAC,CAAmB,CAAC;SAClE;QAED,MAAM,cAAc,GAAG,IAAI,kBAAkB,CAAC;YAC5C,IAAI,EAAE,YAAY;YAClB,IAAI,EAAE,CAAC,CAAC;YACR,OAAO,EAAE,KAAK;YACd,KAAK,EAAE,SAAS;SACjB,CAAC,CAAC,KAAK,CAAC,CAAC,CAAmB,CAAC;QAE9B,sDAAsD;QACtD,KAAK,CAAC;YACJ,MAAM,EAAE,CAAC,QAAQ,EAAE,WAAW,CAAC;YAC/B,OAAO,EAAE,cAAc;YACvB,IAAI,EAAE,eAAe;SACtB,CAAC,CAAC;QACH,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,SAAS,CAAC;QAChC,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC,eAAe,CAAC;QAC5C,IAAI,CAAC,OAAO,GAAG,MAAA,IAAI,CAAC,OAAO,mCAAI,GAAG,CAAC;QACnC,IAAI,CAAC,iBAAiB,GAAG,MAAA,IAAI,CAAC,iBAAiB,mCAAI,IAAI,CAAC;IAC1D,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAA6B;YACvC,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,eAAe,EAAE,IAAI,CAAC,eAAe;YACrC,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;SAC1C,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAED,IAAa,cAAc;QACzB,OAAO,IAAI,CAAC,QAAQ,CAAC,iBAAiB,CAAc,CAAC;IACvD,CAAC;;AA7FD,kBAAkB;AACF,sBAAS,GAAG,cAAc,CAAC;SAFhC,YAAY;AAgGzB,aAAa,CAAC,aAAa,CAAC,YAAY,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Base class for Backbone models.\n */\n\n/* Original source: keras_nlp/models/gpt2/gpt2_backbone.py */\nimport { serialization } from '@tensorflow/tfjs-core';\n\nimport { RandomNormal } from '../../../../initializers';\nimport { input } from '../../../../exports';\nimport { Embedding } from '../../../embeddings';\nimport { SymbolicTensor } from '../../../../engine/topology';\nimport { PositionEmbedding } from '../../modeling/position_embedding';\nimport { add } from '../../../../exports_layers';\nimport { Dropout } from '../../../core';\nimport { TransformerDecoder } from '../../modeling/transformer_decoder';\nimport { getActivation } from '../../../../activations';\nimport { LayerNormalization } from '../../../normalization';\nimport { Backbone } from '../backbone';\n\nfunction gpt2KernelInitializer(stddev = 0.02) {\n  return new RandomNormal({stddev});\n}\n\nexport interface GPT2BackboneArgs  {\n  /**\n   * Integer. The size of the token vocabulary.\n   */\n  vocabularySize: number;\n\n  /**\n   * Integer. The number of transformer layers.\n   */\n  numLayers: number;\n\n  /**\n   * Integer. The number of attention heads for each transformer.\n   * The hidden size must be divisible by the number of attention heads.\n   */\n  numHeads: number;\n\n  /**\n   * Integer. The size of the transformer encoding and pooler layers.\n   */\n  hiddenDim: number;\n\n  /**\n   * Integer. The output dimension of the first Dense layer in a two-layer\n   * feedforward network for each transformer.\n   */\n  intermediateDim: number;\n\n  /**\n   * Float. Dropout probability for the Transformer encoder.\n   * Defaults to 0.2.\n   */\n  dropout?: number;\n\n  /**\n   * Integer. The maximum sequence length that this encoder can consume.\n   * If `null`, `maxSequenceLength` uses the value from sequence length.\n   * This determines the variable shape for positional embeddings.\n   * Defaults to 1024.\n   */\n  maxSequenceLength?: number;\n}\n\n/**\n * GPT-2 core network with hyperparameters.\n *\n * This network implements a Transformer-based decoder network,\n * Generative Pretrained Transformer-2 (GPT-2), as described in\n * [\"Language Models are Unsupervised Multitask Learners\"](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf).\n * It includes the embedding lookups and transformer layers.\n *\n * The default constructor gives a fully customizable, randomly initialized\n * GPT-2 model with any number of layers, heads, and embedding\n * dimensions. To load preset architectures and weights, use the `fromPreset`\n * constructor.\n *\n * Disclaimer: Pre-trained models are provided on an \"as is\" basis, without\n * warranties or conditions of any kind. The underlying model is provided by a\n * third party and subject to a separate license, available\n * [here](https://github.com/openai/gpt-2).\n *\n *\n * Example usage:\n * ```js\n * const tokenIds = tf.ones([1, 12]), dtype=\"int32\");\n * const paddingMask = tf.tensor(\n *  [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], 'int32');\n *\n * # Pretrained GPT-2 decoder.\n * model = GPT2Backbone.fromPreset(\"gpt2_base_en\");\n * model.apply(inputData, {paddingMask});\n *\n * # Randomly initialized GPT-2 decoder with custom config.\n * model = kerasNlp.models.GPT2Backbone({\n *     vocabularySize: 50257,\n *     numLayers: 12,\n *     numHeads: 12,\n *     hiddenDim: 768,\n *     intermediateDim: 3072,\n *     maxSequenceLength: 1024,\n * });\n * model.apply(inputData, {paddingMask});\n * ```\n */\nexport class GPT2Backbone extends Backbone {\n  /** @nocollapse */\n  static override className = 'GPT2Backbone';\n\n  private vocabularySize: number;\n  private numLayers: number;\n  private numHeads: number;\n  private hiddenDim: number;\n  private intermediateDim: number;\n  private dropout: number;\n  private maxSequenceLength: number;\n\n  constructor(args: GPT2BackboneArgs) {\n    args.dropout = args.dropout ?? 0.1;\n    args.maxSequenceLength = args.maxSequenceLength ?? 1024;\n\n    // Inputs\n    const tokenIds = input({shape: [null], dtype: 'int32', name: 'token_ids'});\n    const paddingMask =\n      input({shape: [null], dtype: 'int32', name: 'padding_mask'});\n\n    // Embed tokens, positions.\n    const tokenEmbedding = new Embedding({\n      inputDim: args.vocabularySize,\n      outputDim: args.hiddenDim,\n      embeddingsInitializer: gpt2KernelInitializer(0.01),\n      name: 'token_embedding',\n    }).apply(tokenIds) as SymbolicTensor;\n\n    const positionEmbedding = new PositionEmbedding({\n      initializer: gpt2KernelInitializer(0.02),\n      sequenceLength: args.maxSequenceLength,\n      name: 'position_embedding',\n    }).apply(tokenEmbedding) as SymbolicTensor;\n\n    // Sum and apply dropout to embeddings.\n    let x = add({name: 'embeddings_add'})\n      .apply([tokenEmbedding, positionEmbedding]) as SymbolicTensor;\n    x = new Dropout({rate: args.dropout, name: 'embeddings_dropout'})\n      .apply(x) as SymbolicTensor;\n\n    // Apply successive transformer decoder blocks.\n    for(let i = 0; i < args.numLayers; i++) {\n      x = new TransformerDecoder({\n        intermediateDim: args.intermediateDim,\n        numHeads: args.numHeads,\n        dropout: args.dropout,\n        layerNormEpsilon: 1e-05,\n        activation: getActivation('gelu'),\n        kernelInitializer: gpt2KernelInitializer(0.02),\n        normalizeFirst: true,\n        name: `transformer_layer_${i}`,\n      }).apply(x, {decoderPaddingMask: paddingMask}) as SymbolicTensor;\n    }\n\n    const sequenceOutput = new LayerNormalization({\n      name: 'layer_norm',\n      axis: -1,\n      epsilon: 1e-05,\n      dtype: 'float32',\n    }).apply(x) as SymbolicTensor;\n\n    // Instantiate using Functional API Model constructor.\n    super({\n      inputs: [tokenIds, paddingMask],\n      outputs: sequenceOutput,\n      name: 'gpt2_backbone'\n    });\n    this.vocabularySize = args.vocabularySize;\n    this.numLayers = args.numLayers;\n    this.numHeads = args.numHeads;\n    this.hiddenDim = args.hiddenDim;\n    this.intermediateDim = args.intermediateDim;\n    this.dropout = args.dropout ?? 0.1;\n    this.maxSequenceLength = args.maxSequenceLength ?? 1024;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config: serialization.ConfigDict = {\n      vocabularySize: this.vocabularySize,\n      numLayers: this.numLayers,\n      numHeads: this.numHeads,\n      hiddenDim: this.hiddenDim,\n      intermediateDim: this.intermediateDim,\n      dropout: this.dropout,\n      maxSequenceLength: this.maxSequenceLength,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override get tokenEmbedding(): Embedding {\n    return this.getLayer('token_embedding') as Embedding;\n  }\n}\nserialization.registerClass(GPT2Backbone);\n"]} |
/** @license See the LICENSE file. */ | ||
/// <amd-module name="@tensorflow/tfjs-layers/dist/version" /> | ||
declare const version = "4.18.0"; | ||
declare const version = "4.19.0-rc.0"; | ||
export { version }; |
/** @license See the LICENSE file. */ | ||
// This code is auto-generated, do not modify this file! | ||
const version = '4.18.0'; | ||
const version = '4.19.0-rc.0'; | ||
export { version }; | ||
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidmVyc2lvbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtbGF5ZXJzL3NyYy92ZXJzaW9uLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBLHFDQUFxQztBQUVyQyx3REFBd0Q7QUFDeEQsTUFBTSxPQUFPLEdBQUcsUUFBUSxDQUFDO0FBQ3pCLE9BQU8sRUFBQyxPQUFPLEVBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKiBAbGljZW5zZSBTZWUgdGhlIExJQ0VOU0UgZmlsZS4gKi9cblxuLy8gVGhpcyBjb2RlIGlzIGF1dG8tZ2VuZXJhdGVkLCBkbyBub3QgbW9kaWZ5IHRoaXMgZmlsZSFcbmNvbnN0IHZlcnNpb24gPSAnNC4xOC4wJztcbmV4cG9ydCB7dmVyc2lvbn07XG4iXX0= | ||
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoidmVyc2lvbi5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtbGF5ZXJzL3NyYy92ZXJzaW9uLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBLHFDQUFxQztBQUVyQyx3REFBd0Q7QUFDeEQsTUFBTSxPQUFPLEdBQUcsYUFBYSxDQUFDO0FBQzlCLE9BQU8sRUFBQyxPQUFPLEVBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKiBAbGljZW5zZSBTZWUgdGhlIExJQ0VOU0UgZmlsZS4gKi9cblxuLy8gVGhpcyBjb2RlIGlzIGF1dG8tZ2VuZXJhdGVkLCBkbyBub3QgbW9kaWZ5IHRoaXMgZmlsZSFcbmNvbnN0IHZlcnNpb24gPSAnNC4xOS4wLXJjLjAnO1xuZXhwb3J0IHt2ZXJzaW9ufTtcbiJdfQ== |
{ | ||
"name": "@tensorflow/tfjs-layers", | ||
"version": "4.18.0", | ||
"version": "4.19.0-rc.0", | ||
"description": "TensorFlow layers API in JavaScript", | ||
@@ -41,4 +41,4 @@ "license": "Apache-2.0 AND MIT", | ||
"peerDependencies": { | ||
"@tensorflow/tfjs-core": "4.18.0" | ||
"@tensorflow/tfjs-core": "4.19.0-rc.0" | ||
} | ||
} |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
No v1
QualityPackage is not semver >=1. This means it is not stable and does not support ^ ranges.
Found 1 instance in 1 package
Unidentified License
License(Experimental) Something that seems like a license was found, but its contents could not be matched with a known license.
Found 3 instances in 1 package
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
Unidentified License
License(Experimental) Something that seems like a license was found, but its contents could not be matched with a known license.
Found 3 instances in 1 package
31060693
184320
5