a2
Advanced tools
| import logging as logging_std_library | ||
| import a2.training.benchmarks as timer | ||
| import a2.utils.utils | ||
| import transformers | ||
| from transformers.trainer import * # Ugly but probably needed ... # noeq | ||
# Resolve torch through the project helper rather than a plain `import torch`;
# presumably this defers the import or raises a friendlier error when torch is
# missing — TODO confirm behavior of a2.utils.utils._import_torch.
torch = a2.utils.utils._import_torch(__file__)
class TimerCallback(transformers.TrainerCallback):
    """Trainer callback that feeds epoch/batch boundaries into a CPUGPUTimer.

    Example:
        tmr = timer.CPUGPUTimer()
        trainer = transformers.Trainer(
            # Other args...
            callbacks=[TimerCallback(tmr, gpu=True)]
        )
        trainer.train()
        tmr.print_all_time_stats()
    """

    def __init__(self, timer, gpu=False):
        super().__init__()
        # NOTE: the `timer` parameter shadows the module alias only inside
        # __init__; the hook methods below still see the module-level `timer`.
        self.gpu = gpu
        self.timer = timer

    def on_epoch_begin(self, args, state, control, **kwargs):
        # `timer.TimeType` comes from the a2.training.benchmarks module alias.
        self.timer.start(timer.TimeType.EPOCH)

    def on_epoch_end(self, args, state, control, **kwargs):
        self.timer.end(timer.TimeType.EPOCH)
        self.timer.complete_all()

    def on_step_begin(self, args, state, control, **kwargs):
        logging_std_library.info(
            f"Epoch {int(state.epoch)}: Start iteration step {state.global_step}/{state.max_steps} of training..."
        )
        self.timer.start(timer.TimeType.BATCH)

    def on_step_end(self, args, state, control, **kwargs):
        self.timer.end(timer.TimeType.BATCH)
        # Flush pending (e.g. async GPU) measurements every 10 steps.
        if state.global_step % 10 == 0:
            self.timer.complete_all()
class TimeLoaderWrapper:
    """Wrapper around a DataLoader (*not* a Dataset!) for I/O timing."""

    def __init__(self, loader, timer):
        self.loader = loader
        self.tmr = timer

    @staticmethod
    def time_loader(loader, tmr):
        """Generator that brackets each fetch from `loader` with IO timing."""
        if len(loader) > 0:
            tmr.start(timer.TimeType.IO)
        for index, batch in enumerate(loader):
            tmr.end(timer.TimeType.IO)
            # Flush completed measurements every 10 batches.
            if index % 10 == 0:
                tmr.complete_all()
            yield batch
            # Re-arm the IO clock unless this was the final batch.
            if index != len(loader) - 1:
                tmr.start(timer.TimeType.IO)

    def __iter__(self):
        return TimeLoaderWrapper.time_loader(self.loader, self.tmr)

    def __len__(self):
        return len(self.loader)

    def reset(self):
        # Delegate reset when the wrapped loader supports it; no-op otherwise.
        if hasattr(self.loader, "reset"):
            self.loader.reset()
class TrainerWithTimer(Trainer):
    """
    Custom Trainer subclass to support finer-grained (forward/backward) timing.
    Should also use the TimerCallback above so epoch/batch spans are recorded.
    Adapted from original:
    # https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/trainer.py
    """

    def __init__(self, *args, **kwargs):
        # Previously this printed a debug line and unconditionally used
        # kwargs["callbacks"][0], which raised KeyError/IndexError when no
        # callbacks were passed and silently mis-picked when the timer callback
        # was not first. Discover it robustly instead; fall back to no timing.
        callbacks = kwargs.get("callbacks") or []
        self.timer_callback = next(
            (cb for cb in callbacks if isinstance(cb, TimerCallback) or hasattr(cb, "timer")),
            None,
        )
        if self.timer_callback is not None:
            self.tmr = self.timer_callback.timer
            self.tmr_gpu = getattr(self.timer_callback, "gpu", False)
        else:
            # No timer callback supplied: timing guards below become no-ops.
            self.tmr = None
            self.tmr_gpu = False
        super().__init__(*args, **kwargs)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Identical to the upstream implementation except that the forward and
        backward passes are bracketed with timer start/end calls.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. The dictionary will be
                unpacked before being fed to the model. Most models expect the
                targets under the argument `labels`.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            # SageMaker model parallel handles forward+backward itself; no
            # fine-grained timing is possible on this path.
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)
        if self.tmr:
            self.tmr.start(timer.TimeType.FORWARD)
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.tmr:
            self.tmr.end(timer.TimeType.FORWARD)
            self.tmr.start(timer.TimeType.BACKWARD)
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)
        if self.tmr:
            self.tmr.end(timer.TimeType.BACKWARD)
        return loss.detach() / self.args.gradient_accumulation_steps

    def get_train_dataloader(self) -> DataLoader:
        """Return the upstream dataloader, wrapped for I/O timing if enabled."""
        dl = super().get_train_dataloader()
        if self.tmr:
            return TimeLoaderWrapper(dl, self.tmr)
        return dl

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        # Checkpoint timing was deliberately disabled; left here for reference.
        # self.tmr.start(timer.TimeType.OTHER)
        super().save_model(output_dir=output_dir, _internal_call=_internal_call)
        # self.tmr.end(timer.TimeType.OTHER)
+9
-6
| Metadata-Version: 2.1 | ||
| Name: a2 | ||
| Version: 0.10.4 | ||
| Version: 0.10.5 | ||
| Summary: Package for predicting information about the weather from social media data as application 2 for maelstrom project | ||
@@ -14,2 +14,3 @@ Author: Kristian Ehlert | ||
| Provides-Extra: deberta | ||
| Provides-Extra: deberta-tf | ||
| Provides-Extra: extend-exclude | ||
@@ -22,3 +23,4 @@ Provides-Extra: fancy-plotting | ||
| Provides-Extra: xarray-extra | ||
| Requires-Dist: bottleneck (>=1.3.7,<2.0.0) ; extra == "deberta" | ||
| Requires-Dist: accelerate (>=0.28.0,<0.29.0) | ||
| Requires-Dist: bottleneck (>=1.3.7,<2.0.0) ; extra == "deberta" or extra == "deberta-tf" | ||
| Requires-Dist: datasets (>=2.11.0,<3.0.0) | ||
@@ -34,3 +36,3 @@ Requires-Dist: ecmwflibs (==0.5.3) ; extra == "xarray-extra" | ||
| Requires-Dist: netcdf4 (>=1.6.5,<2.0.0) ; extra == "xarray-extra" | ||
| Requires-Dist: nltk (>=3.8.1,<4.0.0) ; extra == "deberta" | ||
| Requires-Dist: nltk (>=3.8.1,<4.0.0) ; extra == "deberta" or extra == "deberta-tf" | ||
| Requires-Dist: pandas (>=1.4.2,<2.0.0) | ||
@@ -43,7 +45,8 @@ Requires-Dist: plotly (>=5.11.0,<6.0.0) ; extra == "fancy-plotting" | ||
| Requires-Dist: seaborn (>=0.12.1,<0.13.0) ; extra == "fancy-plotting" | ||
| Requires-Dist: sentence-transformers (>=2.3.1,<3.0.0) ; extra == "deberta" | ||
| Requires-Dist: sentencepiece (>=0.1.98,<0.2.0) ; extra == "llama-chatbot" or extra == "deberta" | ||
| Requires-Dist: sentence-transformers (>=2.3.1,<3.0.0) ; extra == "deberta" or extra == "deberta-tf" | ||
| Requires-Dist: sentencepiece (>=0.1.98,<0.2.0) ; extra == "llama-chatbot" or extra == "deberta" or extra == "deberta-tf" | ||
| Requires-Dist: spacymoji (>=3.0.1,<4.0.0) ; extra == "tweets" | ||
| Requires-Dist: transformers (>=4.29.2,<5.0.0) ; extra == "llama-chatbot" or extra == "deberta" | ||
| Requires-Dist: tf-keras (>=2.16.0,<3.0.0) ; extra == "deberta-tf" | ||
| Requires-Dist: transformers (>=4.29.2,<5.0.0) ; extra == "llama-chatbot" or extra == "deberta" or extra == "deberta-tf" | ||
| Requires-Dist: tweepy (>=4.10.0,<5.0.0) ; extra == "tweets" | ||
| Requires-Dist: xarray[io] (>=2023.12.0,<2024.0.0) ; extra == "xarray-extra" |
+11
-1
| [tool.poetry] | ||
| name = "a2" | ||
| version = "0.10.4" | ||
| version = "0.10.5" | ||
| description = "Package for predicting information about the weather from social media data as application 2 for maelstrom project" | ||
@@ -44,2 +44,4 @@ authors = ["Kristian Ehlert <kristian.ehlert@4-cast.de>"] | ||
| netcdf4 = {version = "^1.6.5", optional = true} | ||
| tf-keras = {version = "^2.16.0", optional = true} | ||
| accelerate = {version = "^0.28.0", optional = true} | ||
@@ -111,2 +113,10 @@ [[tool.poetry.source]] | ||
| ] | ||
| deberta-tf = [ | ||
| "tf-keras", | ||
| "transformers", | ||
| "sentencepiece", | ||
| "bottleneck", | ||
| "sentence-transformers", | ||
| "nltk", | ||
| ] | ||
| tracking = ["mantik"] | ||
@@ -113,0 +123,0 @@ xarray-extra = ["xarray", "ecmwflibs", "netcdf4"] |
@@ -120,3 +120,3 @@ import dataclasses | ||
| def get_model(self, params: t.Dict, mantik: bool = True, base_model_trainable: bool = True, use_tf: bool = False): | ||
| def get_model(self, params: t.Dict, mantik: bool = True, base_model_trainable: bool = True): | ||
| db_config = self.db_config_base | ||
@@ -127,8 +127,3 @@ if params is not None: | ||
| db_config.update({"num_labels": self.num_labels}) | ||
| if use_tf: | ||
| model = transformers.TFAutoModelForSequenceClassification.from_pretrained( | ||
| self.model_folder, config=db_config | ||
| ) | ||
| else: | ||
| model = transformers.AutoModelForSequenceClassification.from_pretrained(self.model_folder, config=db_config) | ||
| model = transformers.AutoModelForSequenceClassification.from_pretrained(self.model_folder, config=db_config) | ||
| if not base_model_trainable: | ||
@@ -216,2 +211,3 @@ for param in model.base_model.parameters(): | ||
| ) | ||
| if not hyper_tuning: | ||
@@ -218,0 +214,0 @@ args = transformers.TrainingArguments( |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
754047
0.7%57
1.79%9083
1.24%