xdat
Advanced tools
+1
-1
| Metadata-Version: 2.4 | ||
| Name: xdat | ||
| Version: 0.1.295 | ||
| Version: 0.1.296 | ||
| Summary: eXtended Data Analysis Toolkit | ||
@@ -5,0 +5,0 @@ Home-page: https://bitbucket.org/hermetric/xdat/ |
| Metadata-Version: 2.4 | ||
| Name: xdat | ||
| Version: 0.1.295 | ||
| Version: 0.1.296 | ||
| Summary: eXtended Data Analysis Toolkit | ||
@@ -5,0 +5,0 @@ Home-page: https://bitbucket.org/hermetric/xdat/ |
+1
-1
@@ -1,1 +0,1 @@ | ||
| 0.1.295 | ||
| 0.1.296 |
+60
-0
@@ -1284,2 +1284,62 @@ import hashlib | ||
class MaxLossCallback(tf.keras.callbacks.Callback):
    """Stop training once the *training* metric dips below a fixed threshold.

    At the end of each epoch the chosen metric (default ``"loss"``) is read
    from ``logs``. While it stays at or above ``threshold``, the current
    weights are snapshotted as the rollback point. The first time it falls
    below ``threshold``, training stops; if a snapshot from an earlier epoch
    exists, the model's weights are restored to it first.
    """

    def __init__(self, threshold, metric="loss", verbose=1):
        """
        Args:
            threshold: stop once the watched metric drops below this value.
            metric: key to read from the epoch-end ``logs`` dict.
            verbose: nonzero to print a message when stopping.
        """
        super().__init__()
        self.threshold = float(threshold)
        self.metric = metric
        self.verbose = int(verbose)
        # Rollback snapshot from the most recent epoch that stayed above
        # the threshold; all None until the first such epoch completes.
        self.prev_weights = None
        self.prev_loss = None
        self.prev_epoch = None

    def on_train_begin(self, logs=None):
        # Clear any stale snapshot so a reused callback starts fresh.
        self.prev_weights = None
        self.prev_loss = None
        self.prev_epoch = None

    def rollback_to_previous(self):
        # Restore the snapshotted weights; no-op before the first snapshot.
        if self.prev_weights is not None:
            self.model.set_weights(self.prev_weights)

    def _stop_and_maybe_rollback(self, epoch, observed):
        # Threshold crossed: roll back (when possible), report, and stop.
        had_snapshot = self.prev_weights is not None
        if had_snapshot:
            self.rollback_to_previous()
        if self.verbose:
            prefix = (
                f"\nMaxLossCallback: {self.metric}={observed:.6f} "
                f"< {self.threshold:.6f} at epoch={epoch}. "
            )
            if had_snapshot:
                tail = (
                    f"Rolled back to epoch={self.prev_epoch} "
                    f"({self.metric}={self.prev_loss:.6f}) and stopping."
                )
            else:
                tail = "No previous epoch to rollback to; stopping."
            print(prefix + tail)
        self.model.stop_training = True

    def on_epoch_end(self, epoch, logs=None):
        observed = (logs or {}).get(self.metric)
        if observed is None:
            # Metric absent from this epoch's logs; nothing to decide.
            return
        if float(observed) < self.threshold:
            self._stop_and_maybe_rollback(epoch, float(observed))
            return
        # Still at/above the floor: record this epoch as the rollback point.
        self.prev_weights = self.model.get_weights()
        self.prev_loss = float(observed)
        self.prev_epoch = epoch
| @register_keras_serializable() | ||
@@ -1286,0 +1346,0 @@ def mape_loss(y_true, y_pred): |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
929913
0.23%9007
0.55%