adaptor
Advanced tools
| Metadata-Version: 2.1 | ||
| Name: adaptor | ||
| Version: 0.2.4 | ||
| Version: 0.2.5 | ||
| Summary: Adaptor: Objective-centric Adaptation Framework for Language Models. | ||
@@ -14,3 +14,3 @@ Home-page: https://github.com/gaussalgo/adaptor | ||
| Classifier: Programming Language :: Python :: 3.8 | ||
| Requires-Python: >=3.8 | ||
| Requires-Python: >=3.7 | ||
| Description-Content-Type: text/markdown | ||
@@ -22,2 +22,5 @@ License-File: LICENSE | ||
| Requires-Dist: accelerate>=0.20.1 | ||
| Requires-Dist: peft<0.13.0,>=0.10.0 | ||
| Requires-Dist: prefetch-generator>=1.0.3 | ||
| Requires-Dist: numpy<1.24 | ||
| Provides-Extra: generative | ||
@@ -24,0 +27,0 @@ Requires-Dist: sacrebleu; extra == "generative" |
@@ -5,2 +5,5 @@ torch>=1.7 | ||
| accelerate>=0.20.1 | ||
| peft<0.13.0,>=0.10.0 | ||
| prefetch-generator>=1.0.3 | ||
| numpy<1.24 | ||
@@ -7,0 +10,0 @@ [examples] |
| LICENSE | ||
| MANIFEST.in | ||
| README.md | ||
@@ -4,0 +3,0 @@ setup.cfg |
+84
-16
@@ -0,1 +1,3 @@ | ||
| import copy | ||
| import itertools | ||
| import logging | ||
@@ -5,10 +7,12 @@ import os | ||
| from transformers import WEIGHTS_NAME | ||
| from peft import PeftModel | ||
| from transformers import WEIGHTS_NAME, TrainerState | ||
| import torch | ||
| from transformers import Trainer, BatchEncoding | ||
| from transformers.modeling_utils import unwrap_model | ||
| from transformers.trainer import TRAINER_STATE_NAME | ||
| from .lang_module import LangModule | ||
| from .schedules import Schedule | ||
| from .utils import AdaptationArguments | ||
| from .utils import AdaptationArguments, SavingStrategy, PEFT_BASE_MODEL_CHECKPOINT_SUBDIR | ||
@@ -30,2 +34,3 @@ logger = logging.getLogger() | ||
| eval_metrics_prefix = "eval" | ||
| args: AdaptationArguments | ||
@@ -48,2 +53,9 @@ def __init__(self, lang_module: LangModule, schedule: Schedule, args: AdaptationArguments, **kwargs): | ||
| all_objectives_ids = list(map(str, self.schedule.objectives["train"].values())) | ||
| if len(set(all_objectives_ids)) < len(all_objectives_ids): | ||
| duplicates = [identifier for identifier in all_objectives_ids if all_objectives_ids.count(identifier) > 1] | ||
| raise ValueError("These objectives have identical identifiers: %s; This would cause " | ||
| "incorrect persistence of checkpoints for your objectives." % set(duplicates)) | ||
| lang_module.finalize() | ||
| super().__init__(model=lang_module, | ||
@@ -54,3 +66,3 @@ args=args, | ||
| data_collator=self.flattened_collator, | ||
| compute_metrics=None, # would require a static prediction format among objectives | ||
| compute_metrics=None, # logged metrics are handled by Objectives | ||
| callbacks=orig_callbacks + [schedule.should_stop_check_callback()], | ||
@@ -102,9 +114,52 @@ **kwargs) | ||
| def _save_module(self, module: torch.nn.Module, output_module_path: str) -> None: | ||
| # simple wrapper to save an arbitrary model to a directory in a standard format | ||
| # for each objective, we also persist a shared tokenizer to make each Objective independently loadable | ||
| self.model.tokenizer.save_pretrained(output_module_path) | ||
| if hasattr(module, "save_pretrained") or hasattr(unwrap_model(module), "save_pretrained"): | ||
| # if the head module has "save_pretrained" method, it will be called for persistence | ||
| module.save_pretrained(output_module_path, use_diff=False, safe_serialization=False) | ||
| else: | ||
| # otherwise, we persist only a raw pytorch module | ||
| torch.save(module.state_dict(), os.path.join(output_module_path, WEIGHTS_NAME)) | ||
| def save_model(self, output_dir: Optional[str] = None, **kwargs) -> None: | ||
| # HF native reload compatibility | ||
| objectives_counter = {str(obj): 0 for obj in self.schedule.objectives["train"].values()} | ||
| all_objectives = set(itertools.chain(self.schedule.objectives["train"].values(), | ||
| self.schedule.objectives["eval"].values())) | ||
| for objective_id in self.schedule.objectives["train"].keys(): | ||
| module = self.model.trainable_models[str(objective_id)] | ||
| objective = self.schedule.objectives["train"][int(objective_id)] | ||
| objectives_counter = {str(obj): 0 for obj in all_objectives} | ||
| os.makedirs(output_dir, exist_ok=True) | ||
| # also save the base model, if any of our objectives are peft models | ||
| if (self.args.save_peft_base_model and any( | ||
| isinstance(o.compatible_head_model, PeftModel) for o in self.schedule.objectives["train"].values())): | ||
| # For simplicity, we assume that base models for all pefts are the same | ||
| # -- this might be violated only if the user passes custom model_head to Objective | ||
| # and additionally creates a peft module on it. | ||
| # With this assumption, we retrieve a base model from an arbitrary (i.e. the first) peft-model objective | ||
| peft_obj = next(o for o in self.schedule.objectives["train"].values() | ||
| if isinstance(o.compatible_head_model, PeftModel)) | ||
| orig_model = copy.deepcopy(peft_obj.compatible_head_model) | ||
| while isinstance(orig_model, PeftModel): | ||
| # we find cases where unload() does not return the base model on the first call | ||
| orig_model = orig_model.unload() | ||
| base_model_path = os.path.join(output_dir, PEFT_BASE_MODEL_CHECKPOINT_SUBDIR) | ||
| self._save_module(orig_model, base_model_path) | ||
| logger.info(f"Base model for PEFT objectives saved in {base_model_path}") | ||
| for objective in all_objectives: | ||
| if not objective.save_objective_module: | ||
| logger.warning("Skipping objective %s from saving objectives' modules.", objective) | ||
| continue | ||
| module = objective.compatible_head_model | ||
| if (self.args.saving_strategy == SavingStrategy.FINISHED_OBJECTIVES | ||
| and self.objective not in self.schedule.converged_objectives): | ||
| logger.warning("Not saving model for %s as SavingStrategy is set to FINISHED_OBJECTIVES.", objective) | ||
| continue | ||
| output_module_path = os.path.join(output_dir, str(objective)) | ||
@@ -117,13 +172,26 @@ | ||
| # we persist a shared tokenizer and training args either way | ||
| self.model.tokenizer.save_pretrained(output_module_path) | ||
| # training args are shared and persisted in the output_dir root | ||
| torch.save(self.args, os.path.join(output_dir, "training_args.bin")) | ||
| if isinstance(module, PeftModel) and self.args.save_peft_base_model: | ||
| base_model_path = os.path.abspath(os.path.join(output_dir, "base_model")) | ||
| module.peft_config['default'].base_model_name_or_path = base_model_path | ||
| logger.warning("Base model for PEFT objective %s set to %s", objective, base_model_path) | ||
| if hasattr(module, "save_pretrained") or hasattr(unwrap_model(module), "save_pretrained"): | ||
| # if the head module has "save_pretrained" method, it will be called for persistence | ||
| module.save_pretrained(output_module_path, use_diff=True) | ||
| else: | ||
| # otherwise, we persist only a raw pytorch module | ||
| torch.save(module.state_dict(), os.path.join(output_module_path, WEIGHTS_NAME)) | ||
| self._save_module(module, output_module_path) | ||
| logger.warning(f"Model of objective {str(objective)} saved in {output_module_path}") | ||
| if self.args.saving_strategy == SavingStrategy.FIRST_OBJECTIVE: | ||
| logger.warning("Skipping other objectives from saving as the chosen SavingStrategy is FIRST_OBJECTIVE.") | ||
| break | ||
| logger.info(f"Model of objective {str(objective)} saved in {output_module_path}") | ||
| def _load_optimizer_and_scheduler(self, checkpoint: str) -> None: | ||
| # Customizations to support continued training | ||
| # If the training already State exists, it overrides newly-initialized TrainerState (initialized in HF.train()) | ||
| possible_state_path = os.path.join(self.model.model_name_or_path, TRAINER_STATE_NAME) | ||
| if os.path.exists(possible_state_path): | ||
| self.state = TrainerState.load_from_json(possible_state_path) | ||
| logger.warning("Restoring training on global step %s", self.state.global_step) | ||
| # in case of continued training, optimizer exists on model.model_name_or_path | ||
| # if the optimizer.pt does not exist, the `super()._load_optimizer_and_scheduler` does not do anything | ||
| return super()._load_optimizer_and_scheduler(checkpoint=self.model.model_name_or_path) |
| import abc | ||
| import logging | ||
| from functools import lru_cache | ||
| from typing import List, Sequence, Optional, Dict, Iterator, Union | ||
| from typing import List, Sequence, Optional, Dict, Union, Any, Tuple | ||
@@ -10,9 +11,10 @@ import numpy as np | ||
| from sacrebleu import corpus_bleu | ||
| from transformers import PreTrainedTokenizer, BatchEncoding, MBart50Tokenizer, MBartTokenizer | ||
| from transformers import PreTrainedTokenizer, BatchEncoding | ||
| from .evaluator_base import EvaluatorBase | ||
| from .prism import Prism | ||
| from ..utils import Head, AdaptationDataset | ||
| logger = logging.getLogger() | ||
| class GenerativeEvaluator(EvaluatorBase, abc.ABC): | ||
@@ -29,2 +31,3 @@ """ | ||
| compatible_heads: List[Head] = [Head.SEQ2SEQ] | ||
| generation_kwargs: Tuple[Tuple[str, Any]] # to use lru_cache, the args must be hashable | ||
@@ -35,3 +38,4 @@ def __init__(self, | ||
| decides_convergence: Optional[bool] = False, | ||
| additional_sep_char: Optional[str] = None): | ||
| additional_sep_char: Optional[str] = None, | ||
| generation_kwargs: Dict[str, Any] = {}): | ||
| super().__init__(decides_convergence) | ||
@@ -42,2 +46,3 @@ | ||
| self.progress_bar = progress_bar | ||
| self.generation_kwargs = tuple(generation_kwargs.items()) | ||
@@ -49,3 +54,3 @@ @staticmethod | ||
| model: torch.nn.Module, | ||
| tokenizer: PreTrainedTokenizer) -> torch.LongTensor: | ||
| generation_kwargs: Tuple[Tuple[str, Any]]) -> torch.LongTensor: | ||
| """ | ||
@@ -55,12 +60,4 @@ Performs a generation for a single input batch. The results are meant to be cached, | ||
| """ | ||
| if isinstance(tokenizer, MBart50Tokenizer): | ||
| # Forced BOS token for MBart50 | ||
| return model.generate(input_ids=input_ids, attention_mask=attention_mask, | ||
| forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]).detach().cpu() | ||
| elif isinstance(tokenizer, MBartTokenizer): | ||
| # Forced BOS token for MBart | ||
| return model.generate(input_ids=input_ids, attention_mask=attention_mask, | ||
| decoder_start_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]).detach().cpu() | ||
| else: | ||
| return model.generate(input_ids=input_ids, attention_mask=attention_mask).detach().cpu() | ||
| kwargs = dict(generation_kwargs) | ||
| return model.generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs).detach().cpu() | ||
@@ -70,3 +67,3 @@ def _autoregressive_predict(self, | ||
| inputs_batch: Dict[str, torch.LongTensor], | ||
| tokenizer: PreTrainedTokenizer) -> Iterator[torch.LongTensor]: | ||
| tokenizer: PreTrainedTokenizer) -> torch.LongTensor: | ||
| """ | ||
@@ -78,8 +75,4 @@ Performs an iterative generation using the default configuration of the model's `generate` method. | ||
| """ | ||
| assert hasattr(model, "generate"), "If Evaluator(use_generate=True), " \ | ||
| "evaluated model must have its generate() method." | ||
| return self._autoregressive_predict_one(inputs_batch["input_ids"], inputs_batch["attention_mask"], | ||
| model, tokenizer) | ||
| model, self.generation_kwargs) | ||
@@ -106,2 +99,10 @@ @staticmethod | ||
| if self.use_generate: | ||
| assert hasattr(model, "generate"), "If Evaluator(use_generate=True), " \ | ||
| "the evaluated model must have implement a generate() method." | ||
| if hasattr(tokenizer, "lang_code_to_id") and not any("token_id" in k for k, v in self.generation_kwargs): | ||
| logger.warning("Your tokenizer has a `lang_code_to_id` attribute, but no `*token_id` was used in " | ||
| "generation_kwargs. Be sure to check the model docs on how to generate with this model.") | ||
| for batch in dataset: | ||
@@ -199,2 +200,3 @@ with torch.no_grad(): | ||
| **kwargs): | ||
| from .prism import Prism | ||
| # language must be set, see prism.py: MODELS['langs'] for a list of supported langs | ||
@@ -201,0 +203,0 @@ super().__init__(**kwargs) |
+150
-42
@@ -0,11 +1,12 @@ | ||
| import inspect | ||
| import logging | ||
| import inspect | ||
| import os | ||
| from copy import deepcopy | ||
| from typing import List, Dict, Any, Optional | ||
| import torch | ||
| from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForSequenceClassification, \ | ||
| AutoModelForTokenClassification, AutoModelForSeq2SeqLM, AutoModelForCausalLM, \ | ||
| AutoModelForMaskedLM, AutoModelForQuestionAnswering | ||
| from peft import PeftConfig, get_peft_model | ||
| from transformers import PreTrainedTokenizer, AutoTokenizer | ||
| from .utils import Head | ||
| from .utils import Head, HEAD_TO_MODEL_CLS | ||
@@ -27,3 +28,5 @@ logger = logging.getLogger() | ||
| tokenizer: PreTrainedTokenizer | ||
| model_name_or_path: str | ||
| trainable_models: torch.nn.ModuleDict | ||
| peft_base_model: Optional[torch.nn.Module] | ||
| heads_output_sizes: Dict[str, int] = {} | ||
@@ -34,3 +37,3 @@ | ||
| self.model_name_or_path = model_name_or_path | ||
| self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) | ||
| self.tokenizer = self._find_and_load_tokenizer(model_name_or_path) | ||
@@ -40,6 +43,46 @@ # head_kwargs = head_kwargs if head_kwargs is not None else [{}] * len(head_types) | ||
| self.trainable_models = torch.nn.ModuleDict() | ||
| self.peft_base_model = None | ||
| @staticmethod | ||
| def load_head(model_name_or_path: str, | ||
| def _find_and_load_tokenizer(model_name_or_path) -> PreTrainedTokenizer: | ||
| try: | ||
| # New training | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) | ||
| logger.info("Loaded tokenizer from %s", model_name_or_path) | ||
| except OSError: | ||
| # Continued training | ||
| # in Adaptor checkpoints, tokenizers are persisted in the respective objectives' subdirs | ||
| # Hence, here we also look for the tokenizer in the model_name_or_path's subdirs | ||
| root = model_name_or_path | ||
| # continued training | ||
| subdirs = [path for path in os.listdir(root) | ||
| if os.path.isdir(os.path.join(root, path))] | ||
| subdirs_with_tokenizer = [os.path.join(root, subdir) for subdir in subdirs | ||
| if any(f.startswith("tokenizer") for f in os.listdir(os.path.join(root, subdir)))] | ||
| if not subdirs_with_tokenizer: | ||
| raise OSError("Could not find a tokenizer in any of the subdirectories %s " | ||
| "of given model_name_or_path='%s'" % (subdirs, root)) | ||
| tokenizer_dir = subdirs_with_tokenizer[0] | ||
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) | ||
| logger.info("Loaded tokenizer from %s", tokenizer_dir) | ||
| return tokenizer | ||
| @staticmethod | ||
| def _set_peft_trainable_params(model: torch.nn.Module, trainable_model: torch.nn.Module) -> None: | ||
| other_model_params = {n: v for n, v in trainable_model.named_parameters()} | ||
| trainable_params = {n for n, v in trainable_model.named_parameters() if v.requires_grad} | ||
| trainables_count = 0 | ||
| for name, param in model.named_parameters(): | ||
| assert name in other_model_params, "Trying to initialize PEFT modules non-identical base models" | ||
| if name in trainable_params: | ||
| param.requires_grad = True | ||
| trainables_count += 1 | ||
| logger.warning("Set %s parameter tensors for a new head as trainable", trainables_count) | ||
| def load_head(self, | ||
| model_name_or_path: str, | ||
| head_type: Head, | ||
| load_as_peft: bool, | ||
| head_kwargs: Dict[str, Any]) -> torch.nn.Module: | ||
@@ -50,20 +93,57 @@ """ | ||
| :param head_type: type of the requested head | ||
| :param load_as_peft: whether to load the new head as PEFT module or a standard transformers model | ||
| :param head_kwargs: additional initialization arguments, adjusting its default, or persisted config | ||
| :return: transformer with a gead of requested type | ||
| :return: transformer with a head of requested type or a new pytorch model | ||
| """ | ||
| try: | ||
| # trying to load first as a transformer model, and if it fails, as a peft model | ||
| BaseModelCls = HEAD_TO_MODEL_CLS[head_type]["full"] | ||
| if not load_as_peft: | ||
| new_head = BaseModelCls.from_pretrained(model_name_or_path, **head_kwargs) | ||
| else: | ||
| logger.warning("Loading model_name_or_path='%s' as peft model.", model_name_or_path) | ||
| PeftModelCls = HEAD_TO_MODEL_CLS[head_type]["peft"] | ||
| if head_type == Head.SEQ_CLASSIFICATION: | ||
| new_head = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, **head_kwargs) | ||
| elif head_type == Head.TOKEN_CLASSIFICATION: | ||
| new_head = AutoModelForTokenClassification.from_pretrained(model_name_or_path, **head_kwargs) | ||
| elif head_type == Head.SEQ2SEQ: | ||
| new_head = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **head_kwargs) | ||
| elif head_type == Head.CLM: | ||
| new_head = AutoModelForCausalLM.from_pretrained(model_name_or_path, **head_kwargs) | ||
| elif head_type == Head.MLM: | ||
| new_head = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **head_kwargs) | ||
| elif head_type == Head.QA: | ||
| new_head = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path, **head_kwargs) | ||
| else: | ||
| # Rule of thumb: PEFT modules keep track of their own base model | ||
| # In continued training with PEFT, the PeftConfig must be persisted | ||
| try: | ||
| # try loading as an existing PEFT model (=> it already has its PeftConfig) | ||
| peft_model_config = PeftConfig.from_pretrained(model_name_or_path) | ||
| if self.peft_base_model is None: | ||
| # we avoid reloading the base model separately for each lora module | ||
| self.peft_base_model = BaseModelCls.from_pretrained(peft_model_config.base_model_name_or_path) | ||
| new_head = PeftModelCls.from_pretrained(deepcopy(self.peft_base_model), model_name_or_path, | ||
| **head_kwargs) | ||
| # new_peft_model is used to find trainable parameters in continued training/reloading-PEFT case | ||
| new_peft_model = get_peft_model(deepcopy(self.peft_base_model), **head_kwargs) | ||
| self._set_peft_trainable_params(new_head, new_peft_model) | ||
| logger.warning("Reloaded existing PEFT module from %s with base model %s.", | ||
| model_name_or_path, peft_model_config.base_model_name_or_path) | ||
| except ValueError: | ||
| # if loading as existing PEFT is not possible, fall back to loading a brand new PEFT model | ||
| logger.warning("Initializing a new PEFT module.") | ||
| # ValueError: Can't find 'adapter_config.json' at {model_name_or_path} | ||
| # -> we initialize a new PEFT model from a full pre-trained transformer (simplest case) | ||
| assert 'peft_config' in head_kwargs, \ | ||
| ("Initializing an objective with PEFT model requires to pass a 'peft_config' " | ||
| "within `objective_args_for_head_config`, e.g: " | ||
| "`objective = Objective(objective_args_for_head_config={'peft_config': LoraConfig()}`." | ||
| " See the docs on https://huggingface.co/docs/peft/main/en/package_reference/config") | ||
| if self.peft_base_model is None: | ||
| self.peft_base_model = BaseModelCls.from_pretrained(model_name_or_path) | ||
| head_kwargs['peft_config'].base_model_name_or_path = model_name_or_path | ||
| # note that in practice, PEFT model initialisation is never called twice! | ||
| new_head = get_peft_model(deepcopy(self.peft_base_model), **head_kwargs) | ||
| except KeyError: | ||
| # requested head type is not in our map | ||
| logger.warning("Model in %s is not a transformers model. " | ||
| "Trying to load as a Pytorch model." % model_name_or_path) | ||
| new_head = torch.load(model_name_or_path, **head_kwargs) | ||
| except ValueError as e: | ||
| # model type is recognized, but could not be loaded | ||
| raise ValueError("Could not load model from %s as a transformer or peft model." % model_name_or_path) \ | ||
| from e | ||
@@ -74,5 +154,8 @@ return new_head | ||
| head_type: Head, | ||
| load_as_peft: bool, | ||
| objective_id: str, | ||
| checkpoint_dir: Optional[str] = None, | ||
| head_kwargs: Optional[Dict[str, Any]] = None, | ||
| new_head: Optional[torch.nn.Module] = None) -> torch.nn.Module: | ||
| new_head: Optional[torch.nn.Module] = None, | ||
| do_merge: bool = True) -> torch.nn.Module: | ||
| """ | ||
@@ -82,8 +165,12 @@ Registers a selected model to this LangModule, i.e. merges its weights with first one of self.trainable_models, | ||
| :param head_type: if no `new_head` is given, a transformer of self.model_name_or_path | ||
| :param load_as_peft: whether to load the head for the new objective as PEFT (e.g. LoRA) module | ||
| with a head of `head_type` will be registered. | ||
| :param objective_id: key of the new_head model. | ||
| :param objective_id: key of the new_head model used to route data samples | ||
| :param checkpoint_dir: directory to objective's checkpoints. Overrides model_name_or_path in continued training | ||
| :param head_kwargs: if transformer is automatically resolved, additional init args of the transformer, | ||
| passed to AutoModelForXY.from_pretrained() | ||
| :param new_head: if given, this would be a selected model to be registered. | ||
| :return: | ||
| :param do_merge: Whether the newly-registered model should be merged with other objective(s) modules. | ||
| :return: The module for a newly registered objective. | ||
| """ | ||
@@ -94,6 +181,8 @@ # manually-initialized head chosen for this objective will also be merged with other objectives and registered | ||
| if new_head is None: | ||
| new_head = self.load_head(self.model_name_or_path, head_type, head_kwargs) | ||
| new_head = self.load_head(self.model_name_or_path if checkpoint_dir is None else checkpoint_dir, | ||
| head_type, | ||
| load_as_peft, | ||
| head_kwargs) | ||
| # this applies to the 2nd+ -added models: they adopt the shared parameters of the first lang_module | ||
| if len(self.trainable_models) >= 1: | ||
| if do_merge and len(self.trainable_models) >= 1: | ||
| unmatched_modules = self._partially_merge_models(list(self.trainable_models.values())[0], new_head) | ||
@@ -109,3 +198,4 @@ # this can contain a deep stack of layers, hence in general, it can not be checked automatically | ||
| new_model: torch.nn.Module, | ||
| top_level: bool = True) -> List[str]: | ||
| top_level: bool = True, | ||
| no_merge_keys_containing: Optional[str] = None) -> List[str]: | ||
| """ | ||
@@ -131,23 +221,39 @@ Matches and merges shared parameters of the models. | ||
| orig_model_param = getattr(orig_model, new_param_key) | ||
| if orig_model_param.shape == new_model_param.shape and torch.all( | ||
| orig_model_param == new_model_param): | ||
| if (orig_model_param.shape == new_model_param.shape | ||
| and torch.all(orig_model_param == new_model_param)): | ||
| setattr(new_model, new_param_key, orig_model_param) | ||
| assert id(getattr(orig_model, new_param_key)) == id(getattr(new_model, new_param_key)) | ||
| else: | ||
| unmatched_modules += [new_param_key] | ||
| else: | ||
| unmatched_modules += [new_param_key] | ||
| else: | ||
| # non-leaf node -> merge in a separate branch | ||
| for child_attr, child_module in children.items(): | ||
| if hasattr(orig_model, child_attr): | ||
| unmatched_modules += LangModule._partially_merge_models(getattr(orig_model, child_attr), | ||
| getattr(new_model, child_attr), | ||
| top_level=False) | ||
| else: | ||
| if not hasattr(orig_model, child_attr): | ||
| # do not merge if the orig_model does not contain the attribute | ||
| unmatched_modules += list(dict(getattr(new_model, child_attr).named_parameters()).keys()) | ||
| continue | ||
| if (no_merge_keys_containing is not None) and (no_merge_keys_containing in child_attr): | ||
| # do not merge if the attribute is excluded | ||
| unmatched_modules += list(dict(getattr(new_model, child_attr).named_parameters()).keys()) | ||
| continue | ||
| # merge all non-excluded cases | ||
| unmatched_modules += LangModule._partially_merge_models(getattr(orig_model, child_attr), | ||
| getattr(new_model, child_attr), | ||
| top_level=False) | ||
| # check that merge-able modules now refer to the same physical address | ||
| if top_level: | ||
| for i, (new_param_key, orig_model_param) in enumerate(orig_model.named_parameters()): | ||
| if new_param_key in dict(new_model.named_parameters()): | ||
| new_model_param = new_model.get_parameter(new_param_key) | ||
| if orig_model_param.shape == new_model_param.shape and \ | ||
| torch.all(orig_model_param == new_model_param): | ||
| assert id(new_model_param) == id(orig_model_param) | ||
| if new_param_key not in dict(new_model.named_parameters()): | ||
| continue | ||
| if no_merge_keys_containing is not None and no_merge_keys_containing in new_param_key: | ||
| continue | ||
| new_model_param = new_model.get_parameter(new_param_key) | ||
| if not orig_model_param.shape == new_model_param.shape: | ||
| continue | ||
| if not torch.all(orig_model_param == new_model_param): | ||
| continue | ||
| assert id(new_model_param) == id(orig_model_param), "New objective's model was not properly merged." | ||
@@ -170,4 +276,3 @@ return unmatched_modules | ||
| list_of_model_specific_inputs = inspect.getfullargspec(selected_head_model.forward).args | ||
| model_specific_inputs = {k: v for k, v in inputs.items() | ||
| if k in list_of_model_specific_inputs and k not in ("label", "labels")} | ||
| model_specific_inputs = {k: v for k, v in inputs.items() if k in list_of_model_specific_inputs} | ||
| # including labels cause the loss to be computed twice - by objective + by HF models forward() | ||
@@ -201,1 +306,4 @@ # but labels are also used to infer decoder_input_ids of some models, so we need to pass it | ||
| module.gradient_checkpointing_enable() | ||
| def finalize(self) -> None: | ||
| self.peft_base_model = None # unset the shared base model to save memory |
@@ -15,18 +15,2 @@ import abc | ||
| @staticmethod | ||
| def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): | ||
| """ | ||
| Shift input ids one token to the right. | ||
| From transformers.modeling_bart. | ||
| """ | ||
| shifted_input_ids = input_ids.new_zeros(input_ids.shape) | ||
| shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() | ||
| shifted_input_ids[:, 0] = decoder_start_token_id | ||
| assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." | ||
| # replace possible -100 values in labels by `pad_token_id` | ||
| shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) | ||
| return shifted_input_ids | ||
| def __call__(self, | ||
@@ -85,7 +69,11 @@ features: List[Union[BatchEncoding, Dict[str, Iterable[Union[int, float]]]]], | ||
| bos_id = self.model.config.bos_token_id if self.model.config.bos_token_id is not None else 0 | ||
| pad_id = self.model.config.pad_token_id if self.model.config.pad_token_id is not None else 0 | ||
| # CLM -> shift input one token to the right | ||
| out_features["input_ids"] = self.shift_tokens_right(out_features["input_ids"], bos_id, pad_id) | ||
| # no shifting of the labels here: this happens in the corresponding loss fn | ||
| labels = out_features["input_ids"].clone() | ||
| if self.tokenizer.pad_token_id is not None: | ||
| # ignore the padded positions from the loss: without this, CLM will not converge | ||
| labels[labels == pad_id] = -100 | ||
| out_features["labels"] = labels | ||
| return out_features | ||
@@ -92,0 +80,0 @@ |
@@ -63,3 +63,4 @@ import abc | ||
| objective_args_for_head_config: Optional[Dict[str, Any]] = None, | ||
| preloaded_module: Optional[torch.nn.Module] = None) -> torch.nn.Module: | ||
| preloaded_module: Optional[torch.nn.Module] = None, | ||
| merge_objective_module: bool = True) -> torch.nn.Module: | ||
| if self.add_hidden_states_loss: | ||
@@ -75,3 +76,4 @@ # if the loss is computed also from the hidden_states, we make sure they are actually requested | ||
| objective_args_for_head_config, | ||
| preloaded_module) | ||
| preloaded_module, | ||
| merge_objective_module) | ||
@@ -119,3 +121,4 @@ def _loss_for_hidden_states(self, | ||
| similarity_or_distance_loss = torch.ones(student_hidden_unbatched.shape[0]) | ||
| similarity_or_distance_loss = torch.ones(student_hidden_unbatched.shape[0], | ||
| device=student_hidden_unbatched.device) | ||
@@ -127,3 +130,4 @@ return cosine_loss(student_hidden_unbatched, teacher_hidden_unbatched, similarity_or_distance_loss) | ||
| teacher_outputs: ModelOutput, | ||
| attn_mask: torch.LongTensor) -> torch.FloatTensor: | ||
| attn_mask: torch.LongTensor, | ||
| decoder_attn_mask: torch.BoolTensor) -> torch.FloatTensor: | ||
| if hasattr(teacher_outputs, "hidden_states"): | ||
@@ -148,3 +152,3 @@ # encoder-, or decoder-only | ||
| loss = (self._loss_for_hidden_states(student_encoder_hidden, teacher_encoder_hidden, attn_mask) + | ||
| self._loss_for_hidden_states(student_decoder_hidden, teacher_decoder_hidden, attn_mask)) | ||
| self._loss_for_hidden_states(student_decoder_hidden, teacher_decoder_hidden, decoder_attn_mask)) | ||
| else: | ||
@@ -165,2 +169,6 @@ raise ValueError("Please initialize both teacher and student model with `output_hidden_states=True`") | ||
| teacher_inputs = inspect.getfullargspec(self.teacher_model.forward).args | ||
| device = student_logits.device | ||
| if self.teacher_model.device != device: | ||
| self.teacher_model = self.teacher_model.to(device) | ||
| with torch.no_grad(): | ||
@@ -170,6 +178,13 @@ teacher_outputs = self.teacher_model(**{k: v for k, v in inputs.items() if k in teacher_inputs}) | ||
| non_ignored_labels: Optional[torch.BoolTensor] = None # used only in the case of encoder-decoder models | ||
| if self.restrict_loss_to_mask: | ||
| # pick only the predictions of tokens on the attended positions (i.e. ignore the others) | ||
| attn_mask_reshaped = inputs["attention_mask"].unsqueeze(-1).expand_as(student_logits).bool() | ||
| if self.compatible_head_model.config.is_encoder_decoder: | ||
| # encoder-decoder -> attention_mask actually applies to labels | ||
| # we infer the labels attention mask from positions not ignored in the loss | ||
| non_ignored_labels: torch.BoolTensor = (0 < inputs["labels"]) < self.tokenizer.vocab_size | ||
| attn_mask_reshaped = non_ignored_labels.unsqueeze(-1).expand_as(student_logits).bool() | ||
| else: | ||
| # encoder-only or decoder-only model -> attention mask applies to labels | ||
| attn_mask_reshaped = inputs["attention_mask"].unsqueeze(-1).expand_as(student_logits).bool() | ||
| student_logits_flat = torch.masked_select(student_logits, attn_mask_reshaped) | ||
@@ -196,3 +211,4 @@ student_logits_unbatched = student_logits_flat.reshape(-1, student_logits.shape[-1]) | ||
| hidden_loss = self._hidden_states_loss(student_outputs, teacher_outputs, inputs["attention_mask"]) | ||
| hidden_loss = self._hidden_states_loss(student_outputs, teacher_outputs, | ||
| inputs["attention_mask"], non_ignored_labels) | ||
| hidden_loss_scaled = self.hidden_cossim_loss_weight * hidden_loss | ||
@@ -199,0 +215,0 @@ |
| import abc | ||
| import itertools | ||
| import logging | ||
| import os.path | ||
| from functools import partial | ||
| from typing import List, Union, Optional, Iterable, Tuple, Dict, Sequence, Any, Iterator | ||
@@ -23,3 +25,3 @@ | ||
| compatible_head: Head | ||
| given_id: Optional[str] | ||
| given_id: Optional[str] = "" | ||
| epoch: int | ||
@@ -40,5 +42,11 @@ num_steps: int | ||
| evaluators: Dict[str, List[EvaluatorBase]] | ||
| data_iteration_offset: int | ||
| routing_id: torch.Tensor | ||
| num_samples_per_log: Dict[str, int] | ||
| num_samples_to_prefetch: int = 10 | ||
| peft_objective: bool | ||
| save_objective_module: bool | ||
| def __init__(self, | ||
@@ -51,4 +59,8 @@ lang_module: LangModule, | ||
| val_evaluators: Sequence[EvaluatorBase] = (), | ||
| train_dataset_length: Optional[int] = None, | ||
| val_dataset_length: Optional[int] = None, | ||
| share_other_objective_head: Optional["Objective"] = None, | ||
| objective_module: Optional[torch.nn.Module] = None, | ||
| merge_objective_module: bool = True, | ||
| save_objective_module: bool = True, | ||
| objective_args_for_head_config: Dict[str, Any] = {}, | ||
@@ -59,3 +71,6 @@ objective_id: Optional[str] = "", | ||
| max_samples_per_eval_log: int = 10000, | ||
| remember_last_input: Optional[bool] = False): | ||
| data_iteration_offset: int = 0, | ||
| prefetch_in_parallel_thread: bool = False, | ||
| remember_last_input: Optional[bool] = False, | ||
| peft_objective: Optional[bool] = False): | ||
| """ | ||
@@ -65,3 +80,3 @@ Shared initialisation logic of every Objective. | ||
| initialises state variables for logging, registers evaluators, | ||
| initialises data set inputs and labels either from path to .txt files, or a lists of strings. | ||
| initialises data set inputs and labels either from path to .txt files, or a lists/iterables of strings. | ||
@@ -74,4 +89,8 @@ :param lang_module: LangModule to register a model of this objective into. | ||
| :param val_evaluators: Evaluators to be called on every evaluation step on validation outputs. | ||
| :param train_dataset_length: Circumvent auto inference of the train dataset length and set it manually. | ||
| :param val_dataset_length: Circumvent auto inference of the validation dataset length and set it manually. | ||
| :param share_other_objective_head: If given, this objective will share module with other given objective. | ||
| :param objective_module: If given, this module will be registered for this objective. | ||
| :param merge_objective_module: If to merge the newly initialized or passed objective's module with others. | ||
| :param save_objective_module: If to separately save the module of this objective on calling save_model. | ||
| :param objective_args_for_head_config: Extra arguments that can be passed to .from_pretrained() on head init. | ||
@@ -88,7 +107,9 @@ :param objective_id: Identifier of this objective, used in logging and checkpoints persistence. | ||
| """ | ||
| self.routing_id = torch.tensor(id(self)) | ||
| self.batch_size = batch_size | ||
| self.tokenizer = lang_module.tokenizer | ||
| self.given_id = objective_id | ||
| self.objective_id = objective_id | ||
| self.peft_objective = peft_objective | ||
| self.loss_weight = loss_weight | ||
| self.num_steps = 0 | ||
@@ -98,6 +119,2 @@ self.remember_last_input = remember_last_input | ||
| self.compatible_head_model = self.register_compatible_head_model(lang_module, | ||
| share_other_objective_head, | ||
| objective_args_for_head_config, | ||
| objective_module) | ||
| self.epoch = 0 | ||
@@ -109,3 +126,13 @@ self.dataset_length = {"train": 0, "eval": 0} | ||
| self.max_samples_per_log = {"train": max_samples_per_log, "eval": max_samples_per_eval_log} | ||
| self.data_iteration_offset = 0 | ||
| self.prefetch_in_parallel_thread = prefetch_in_parallel_thread | ||
| # register_compatible_head_model also sets the dataset iterator in continued training | ||
| self.compatible_head_model = self.register_compatible_head_model(lang_module, | ||
| share_other_objective_head, | ||
| objective_args_for_head_config, | ||
| objective_module, | ||
| merge_objective_module) | ||
| self.save_objective_module = save_objective_module | ||
| if data_iteration_offset: # can override obtained trainer_state["global_step"] in continued training | ||
| self.data_iteration_offset = data_iteration_offset | ||
| self.progressbar = {} | ||
@@ -119,9 +146,12 @@ | ||
| if isinstance(texts_or_path, str): | ||
| self._check_supported_data_source_format(texts_or_path) | ||
| self.texts_path = texts_or_path | ||
| with open(self.texts_path) as f: | ||
| self.dataset_length["train"] = len(f.readlines()) | ||
| else: | ||
| self.texts = texts_or_path | ||
| self.dataset_length["train"] = len(self.texts) | ||
| if train_dataset_length is None: | ||
| self.dataset_length["train"] = self._compute_data_source_length(texts_or_path) | ||
| else: | ||
| self.dataset_length["train"] = train_dataset_length | ||
| for split, given_evaluators in zip(("train", "eval"), (train_evaluators, val_evaluators)): | ||
@@ -139,9 +169,44 @@ for given_evaluator in given_evaluators: | ||
| if isinstance(val_texts_or_path, str): | ||
| self._check_supported_data_source_format(val_texts_or_path) | ||
| self.val_texts_path = val_texts_or_path | ||
| with open(self.val_texts_path) as f: | ||
| self.dataset_length["eval"] = len(f.readlines()) | ||
| else: | ||
| self.val_texts = val_texts_or_path | ||
| self.dataset_length["eval"] = len(self.val_texts) | ||
| if val_dataset_length is None: | ||
| self.dataset_length["eval"] = self._compute_data_source_length(val_texts_or_path) | ||
| else: | ||
| self.dataset_length["eval"] = val_dataset_length | ||
| def _check_supported_data_source_format(self, path: str) -> None: | ||
| if not os.path.exists(path): | ||
| raise FileNotFoundError("Objective %s: Given path '%s' does not exist" % (self, path)) | ||
| # when the passed data source is a file, we check that it is in a supported format: | ||
| # we support .txt and .tar.gz files | ||
| supported_file_formats = ['.txt', '.gz'] | ||
| if not any(path.endswith(suffix) for suffix in supported_file_formats): | ||
| logger.warning("Objective %s's given {val_}texts_or_path `%s` is not a List " | ||
| "and does not end with one of supported suffixes: ['.txt', '.gz']." | ||
| "We'll assume that the file is a line-separated plaintext file." % (self, path)) | ||
| def _compute_data_source_length(self, texts_or_path: Union[str, List[str]]) -> int: | ||
| if isinstance(texts_or_path, str): | ||
| if texts_or_path.endswith('.gz'): | ||
| import io | ||
| import gzip | ||
| with io.TextIOWrapper(io.BufferedReader(gzip.open(texts_or_path))) as f: # type: ignore | ||
| return sum(1 for _ in f) # more efficient line count | ||
| else: | ||
| with open(texts_or_path, "rb") as f: | ||
| return sum(1 for _ in f) # more efficient line count | ||
| elif isinstance(texts_or_path, list): | ||
| return len(texts_or_path) | ||
| else: | ||
| raise ValueError("Objective %s's data format (%s) is not supported. " | ||
| "Please pass in a List[str], or str denoting a path to a file." | ||
| % (self, type(texts_or_path))) | ||
| def per_objective_log(self, split: str) -> Dict[str, float]: | ||
@@ -154,2 +219,5 @@ """ | ||
| out_logs = {} | ||
| if split == "eval" and self.val_texts is None and self.val_texts_path is None: | ||
| logger.warning("Skipping evaluation for %s" % self) | ||
| return out_logs | ||
| # aggregate per-progress_bar-steps, or per-evaluation-steps, keep the results of unprocessed evaluations | ||
@@ -165,3 +233,2 @@ loss_history = self.loss_history[split][-self.max_samples_per_log[split]:] | ||
| dataset = self.get_dataset(split, 0, self.compatible_head_model.device, | ||
| firstn=self.max_samples_per_log[split], | ||
| add_oid=False, | ||
@@ -257,2 +324,3 @@ is_training_dataset=False) | ||
| if self.progressbar[split] is not None: | ||
| self.progressbar[split].update(1) | ||
| self.progressbar[split].set_postfix(refresh=False, split=split, loss=loss.item(), epoch=self.epoch) | ||
@@ -277,5 +345,4 @@ | ||
| split: str, | ||
| objective_i: int, | ||
| device: Union[str, torch.device], | ||
| firstn: Optional[int] = None, | ||
| objective_i: Optional[int] = 0, | ||
| device: Optional[Union[str, torch.device]] = None, | ||
| add_oid: bool = True, | ||
@@ -287,7 +354,8 @@ is_training_dataset: bool = True, | ||
| :param split: A split of the retrieved dataset. `train` or `eval`. | ||
| :param objective_i: Rank of this objective in schedule. Used only to properly set up progress bar. | ||
| :param objective_i: Objective's rank used only to properly set up parallel progress bars. | ||
| :param device: Device to transfer this data set to. | ||
| :param firstn: If given, a number of the retrieved items from the dataset. | ||
| :param add_oid: Whether to append objective id to the match. Required for forward pass over LangModule. | ||
| :param is_training_dataset: Whether this dataset is used for training -> if to update the epochs counter. | ||
| Note that training dataset can also be iterated outside main training loop. | ||
| :param show_progressbar: Whether to maintain a dataset iterator progress bar for this objective. | ||
@@ -301,22 +369,13 @@ :return: TransformerAdaptationDataset wrapping a data set of this objective. | ||
| if show_progressbar: | ||
| self.progressbar[split] = trange(self.dataset_length[split] // self.batch_size, | ||
| desc=str(self), | ||
| unit="batches", | ||
| position=objective_i, | ||
| leave=True) | ||
| self.progressbar[split].set_postfix(refresh=False, split=split, epoch=self.epoch, loss=-1) | ||
| else: | ||
| # we do not update loss, if no progress bar is pertained | ||
| self.progressbar[split] = None | ||
| inputs_iter = self._get_inputs_iterator(split) | ||
| def _sample_to_device(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: | ||
| return {k: v.to(device) if k != "oid" else v for k, v in sample.items()} | ||
| def _sample_to_device(chosen_device: Optional[Union[str, torch.device]], | ||
| sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]: | ||
| if chosen_device is None: | ||
| # default device is a device of the model assigned to this objective, if it is set | ||
| # if it is not, we resort to "cpu" | ||
| # in classic training, the model is always assigned when the dataset is requested | ||
| chosen_device = self.compatible_head_model.device if self.compatible_head_model is not None else "cpu" | ||
| return {k: v.to(chosen_device) for k, v in sample.items()} | ||
| def _add_oid(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: | ||
| sample["oid"] = torch.tensor(id(self)) | ||
| return sample | ||
| def _remember_input(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: | ||
@@ -330,17 +389,47 @@ self.last_input = sample | ||
| device_inputs_iter = map(_sample_to_device, inputs_iter) | ||
| def _add_oid(sample: Union[BatchEncoding, Dict[str, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: | ||
| sample["oid"] = self.routing_id | ||
| return sample | ||
| device_inputs_iter = map(partial(_sample_to_device, device), inputs_iter) | ||
| if split == "eval" and self.max_samples_per_log["eval"] is not None: | ||
| device_inputs_iter = itertools.islice(device_inputs_iter, self.max_samples_per_log["eval"]) | ||
| self.dataset_length["eval"] = self.max_samples_per_log["eval"] * self.batch_size | ||
| if add_oid: | ||
| device_inputs_iter = map(_add_oid, device_inputs_iter) | ||
| if firstn is not None and firstn < self.dataset_length[split]: | ||
| device_inputs_iter = itertools.islice(device_inputs_iter, firstn) | ||
| if self.remember_last_input: | ||
| device_inputs_iter = map(_remember_input, device_inputs_iter) | ||
| if self.prefetch_in_parallel_thread: | ||
| from prefetch_generator import BackgroundGenerator | ||
| device_inputs_iter = BackgroundGenerator(device_inputs_iter, max_prefetch=self.num_samples_to_prefetch) | ||
| # Support for continued training: | ||
| # if nonempty dataset AND this is a first train iteration, fast-forward data iteration to the self.offset_steps | ||
| should_offset_dataset = self.dataset_length[split] and (split == "train" and self.epoch == 1) | ||
| dataset_samples_offset = self.data_iteration_offset % self.dataset_length[split] if should_offset_dataset else 0 | ||
| # adjust the current epoch accordingly | ||
| offset_epoch = (self.data_iteration_offset // self.dataset_length[split]) | ||
| if offset_epoch: | ||
| self.epoch = offset_epoch + 1 | ||
| # do not apply the offset again in the next epochs | ||
| self.data_iteration_offset = 0 | ||
| if show_progressbar: | ||
| device_inputs_iter = map(_update_pbar, device_inputs_iter) | ||
| # set up a new progressbar object | ||
| self.progressbar[split] = trange(self.dataset_length[split] // self.batch_size, | ||
| initial=dataset_samples_offset, | ||
| desc=str(self), | ||
| unit="batches", | ||
| position=objective_i, | ||
| leave=True) | ||
| self.progressbar[split].set_postfix(refresh=False, split=split, epoch=self.epoch, loss=-1) | ||
| else: | ||
| # we do not update loss, if no progress bar is pertained | ||
| self.progressbar[split] = None | ||
| return TransformerAdaptationDataset(device_inputs_iter, self.dataset_length[split]) | ||
| return TransformerAdaptationDataset(device_inputs_iter, self.dataset_length[split], dataset_samples_offset) | ||
@@ -415,3 +504,4 @@ def compute_loss_on_last_sample(self) -> torch.FloatTensor: | ||
| objective_args_for_head_config: Optional[Dict[str, Any]] = None, | ||
| preloaded_module: Optional[torch.nn.Module] = None) -> torch.nn.Module: | ||
| preloaded_module: Optional[torch.nn.Module] = None, | ||
| do_merge: bool = True) -> torch.nn.Module: | ||
| """ | ||
@@ -432,2 +522,10 @@ Resolves a model of this objective in given lang_module. Either requests LangModule to provide model with | ||
| if (self.peft_objective and "peft_config" not in head_config) or \ | ||
| (not self.peft_objective and "peft_config" in head_config): | ||
| raise ValueError("When loading an objective with a PEFT module, you must both set the `peft_objective=True`" | ||
| " *and* provide a `peft_config` in objective_args_for_head_config argument.") | ||
| # Support for continued training: | ||
| checkpoint_dir = None | ||
| possible_checkpoint_path = os.path.join(lang_module.model_name_or_path, str(self)) | ||
| if other_objective is not None: | ||
@@ -437,5 +535,27 @@ logger.warning("Objective %s will use %s head of %s objective", | ||
| preloaded_module = other_objective.compatible_head_model | ||
| elif preloaded_module is not None: | ||
| logger.warning("Objective %s will use the pre-defined model given in `objective_module` parameter.", self) | ||
| elif os.path.exists(possible_checkpoint_path): | ||
| logger.warning("Reloading objective %s's module from checkpoint %s", str(self), possible_checkpoint_path) | ||
| checkpoint_dir = possible_checkpoint_path | ||
| return lang_module.load_training_head(self.compatible_head, str(id(self)), head_config, preloaded_module) | ||
| # if this is a checkpoint path (not a saved lang_module), adjust data iterator according to trainer_state | ||
| trainer_state_path = os.path.join(lang_module.model_name_or_path, "trainer_state.json") | ||
| if os.path.exists(trainer_state_path): | ||
| from transformers import TrainerState | ||
| trainer_state = TrainerState.load_from_json(trainer_state_path) | ||
| logger.warning("Data iteration of %s will continue on a step %s.", self, trainer_state.global_step) | ||
| self.data_iteration_offset = trainer_state.global_step | ||
| else: | ||
| logger.warning("No checkpoint found on %s. Attempting to load a model from '%s'.", | ||
| possible_checkpoint_path, lang_module.model_name_or_path) | ||
| return lang_module.load_training_head(self.compatible_head, | ||
| self.peft_objective, | ||
| str(id(self)), | ||
| checkpoint_dir, | ||
| head_config, | ||
| preloaded_module, | ||
| do_merge) | ||
| def __str__(self) -> str: | ||
@@ -446,4 +566,4 @@ """ | ||
| """ | ||
| if self.given_id: | ||
| return str("%s-%s" % (self.given_id, self.__class__.__name__)) | ||
| if self.objective_id: | ||
| return str("%s-%s" % (self.objective_id, self.__class__.__name__)) | ||
| else: | ||
@@ -492,2 +612,4 @@ return self.__class__.__name__ | ||
| if isinstance(labels_or_path, str): | ||
| # data source is a file: we support .txt and .tar.gz files | ||
| self._check_supported_data_source_format(labels_or_path) | ||
| self.labels_path = labels_or_path | ||
@@ -499,2 +621,3 @@ else: | ||
| if isinstance(val_labels_or_path, str): | ||
| self._check_supported_data_source_format(val_labels_or_path) | ||
| self.val_labels_path = val_labels_or_path | ||
@@ -506,2 +629,3 @@ else: | ||
| if isinstance(text_pair_or_path, str): | ||
| self._check_supported_data_source_format(text_pair_or_path) | ||
| self.text_pair_path = text_pair_or_path | ||
@@ -513,2 +637,3 @@ else: | ||
| if isinstance(val_text_pair_or_path, str): | ||
| self._check_supported_data_source_format(val_text_pair_or_path) | ||
| self.val_text_pair_path = val_text_pair_or_path | ||
@@ -524,3 +649,4 @@ else: | ||
| objective_args_for_head_config: Optional[Dict[str, Any]] = None, | ||
| preloaded_module: Optional[torch.nn.Module] = None) -> torch.nn.Module: | ||
| preloaded_module: Optional[torch.nn.Module] = None, | ||
| merge_objective_module: bool = True) -> torch.nn.Module: | ||
| """ | ||
@@ -551,4 +677,4 @@ Additionally adds labels into a configuration of this objective's model in lang_module. | ||
| return super().register_compatible_head_model(lang_module, other_objective, | ||
| objective_args_for_head_config, preloaded_module) | ||
| return super().register_compatible_head_model(lang_module, other_objective, objective_args_for_head_config, | ||
| preloaded_module, merge_objective_module) | ||
@@ -555,0 +681,0 @@ def _get_inputs_iterator(self, split: str) -> Iterator[Union[BatchEncoding, Dict[str, torch.Tensor]]]: |
| import abc | ||
| from typing import List, Optional, Iterable, Dict, Iterator, Callable, Any, Union | ||
| from typing import List, Optional, Iterable, Dict, Iterator, Callable, Union | ||
@@ -7,3 +7,2 @@ import torch | ||
| from ..lang_module import LangModule | ||
| from ..objectives.objective_base import SupervisedObjective, Objective | ||
@@ -23,6 +22,18 @@ from ..utils import Head | ||
| **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| self.source_lang_id = source_lang_id | ||
| self.target_lang_id = target_lang_id | ||
| super().__init__(*args, **kwargs) | ||
| if hasattr(self.tokenizer, "lang_code_to_id") and self.source_lang_id is not None: | ||
| assert self.source_lang_id in self.tokenizer.vocab, \ | ||
| ("Objective %s's 'src_lang' is not in its tokenizer's vocabulary. " | ||
| "This would cause wrong data encodings." % self.source_lang_id) | ||
| self.tokenizer.src_lang = self.source_lang_id | ||
| if hasattr(self.tokenizer, "lang_code_to_id") and self.target_lang_id is not None: | ||
| assert self.target_lang_id in self.tokenizer.vocab, \ | ||
| ("Objective %s's 'tgt_lang' is not in its tokenizer's vocabulary. " | ||
| "This would cause wrong data encodings." % self.tokenizer.tgt_lang) | ||
| self.tokenizer.tgt_lang = self.target_lang_id | ||
| def _get_seq2seq_collated_iterator(self, | ||
@@ -42,9 +53,5 @@ source_texts: Iterable[str], | ||
| self.tokenizer.tgt_lang = self.target_lang_id | ||
| sample_features = self.tokenizer(source_text, truncation=True) | ||
| sample_features = dict(self.tokenizer(source_text, text_target=target_text, truncation=True)) | ||
| with self.tokenizer.as_target_tokenizer(): | ||
| sample_targets = self.tokenizer(target_text, truncation=True) | ||
| features_batch.append({"input_ids": sample_features.input_ids, | ||
| "attention_mask": sample_features.attention_mask, | ||
| "labels": sample_targets.input_ids}) | ||
| features_batch.append(sample_features) | ||
| if len(features_batch) == self.batch_size: | ||
@@ -112,21 +119,5 @@ yield self.collator(features_batch) | ||
def register_compatible_head_model(self,
                                   lang_module: LangModule,
                                   other_objective: Optional["Objective"] = None,
                                   objective_args_for_head_config: Optional[Dict[str, Any]] = None,
                                   preloaded_module: Optional[torch.nn.Module] = None) -> torch.nn.Module:
    """
    Resolve this seq2seq objective's head module via the parent Objective logic,
    then verify that the resolved module is usable for sequence-to-sequence training.

    :param lang_module: LangModule to register this objective's model into.
    :param other_objective: If given, share the head module of this other objective.
    :param objective_args_for_head_config: Extra arguments passed to head initialisation.
    :param preloaded_module: If given, register this module instead of loading a new one.
    :return: The registered head module.
    """
    # delegate actual resolution/registration to the base Objective implementation
    head_module = super().register_compatible_head_model(lang_module, other_objective,
                                                         objective_args_for_head_config, preloaded_module)
    # seq2seq training requires the model to build decoder inputs from labels
    # (as e.g. transformers.BartModel does) — fail fast on incompatible modules
    assert hasattr(head_module, "prepare_decoder_input_ids_from_labels"), \
        "No head of the loaded LangModule is compatible with %s objective! " \
        "\nNote that the module compatible with " \
        "Sequence2SequenceMixin \nmust have `prepare_decoder_input_ids_from_labels` method, " \
        "see e.g. \ntransformers.BartModel." % self
    return head_module
class Sequence2Sequence(Sequence2SequenceMixin, SupervisedObjective):
    """
    Supervised sequence-to-sequence objective.

    Behaviour is composed entirely from Sequence2SequenceMixin (seq2seq data
    encoding and head registration) and SupervisedObjective (labeled data handling).
    """
    pass
+83
-27
| import abc | ||
| import logging | ||
| from enum import Enum | ||
| from typing import Dict, Iterable, Iterator, Optional | ||
| import os | ||
| import torch | ||
| import peft | ||
| from torch.utils.data import IterableDataset | ||
| import transformers | ||
| from transformers import BatchEncoding, TrainingArguments | ||
| logger = logging.getLogger() | ||
| class Head(Enum): | ||
@@ -30,2 +37,26 @@ SEQ_CLASSIFICATION = 1 | ||
class SavingStrategy(Enum):
    """
    Strategy choosing the objectives for which models are persisted in checkpoints
    (see AdaptationArguments.saving_strategy).
    """
    ALL_OBJECTIVES = 1  # persist the modules of every objective
    FIRST_OBJECTIVE = 2  # persist only the first objective's module
    FINISHED_OBJECTIVES = 3  # persist only objectives that have finished training
# Maps each objective Head to the model class used to load its module:
# "full" for conventional fine-tuning (HF Auto* classes),
# "peft" for parameter-efficient fine-tuning (peft wrapper classes).
HEAD_TO_MODEL_CLS = {
    Head.SEQ_CLASSIFICATION: {"full": transformers.AutoModelForSequenceClassification,
                              "peft": peft.PeftModelForSequenceClassification},
    Head.TOKEN_CLASSIFICATION: {"full": transformers.AutoModelForTokenClassification,
                                "peft": peft.PeftModelForTokenClassification},
    Head.SEQ2SEQ: {"full": transformers.AutoModelForSeq2SeqLM,
                   "peft": peft.PeftModelForSeq2SeqLM},
    Head.CLM: {"full": transformers.AutoModelForCausalLM,
               "peft": peft.PeftModelForCausalLM},
    Head.MLM: {"full": transformers.AutoModelForMaskedLM,
               "peft": NotImplemented},  # no peft model class registered for masked LM here
    Head.QA: {"full": transformers.AutoModelForQuestionAnswering,
              "peft": peft.PeftModelForQuestionAnswering}
}

# Subdirectory of a checkpoint where the PEFT objectives' base model is saved
PEFT_BASE_MODEL_CHECKPOINT_SUBDIR = "base_model"
| class AdaptationDataset(IterableDataset, abc.ABC): | ||
@@ -37,3 +68,8 @@ """ | ||
def __init__(self, length: Optional[int] = None):
    """
    :param length: Total number of samples in this dataset, or None if unknown.
    """
    # number of distributed workers sharding this dataset (1 = single-process training)
    self.world_size = int(os.environ.get("WORLD_SIZE", 1))
    if self.world_size > 1:
        logger.warning("World size for data sampling: %s" % self.world_size)
        # each worker consumes only its 1/world_size share of the samples;
        # guard against unknown length — the original `length // world_size`
        # raised TypeError when length was None
        self.length = length // self.world_size if length is not None else None
    else:
        self.length = length
@@ -48,13 +84,25 @@ def __getitem__(self, index: int) -> BatchEncoding: | ||
| def iter_text_file_per_line(path: str) -> Iterable[str]: | ||
| with open(path) as f: | ||
| for line in f: | ||
| yield line.strip() | ||
| """ | ||
| Iterate over the lines of a file on a given path. | ||
| At this point, `path` is checked to be of a supported format. | ||
| :param path: file path | ||
| """ | ||
| if path.endswith(".gz"): | ||
| import gzip | ||
| import io | ||
| with io.TextIOWrapper(io.BufferedReader(gzip.open(path))) as file: # type: ignore | ||
| for line in file: | ||
| yield line.strip() | ||
| else: | ||
| # assumes plain, newline-separated text file | ||
| with open(path) as f: | ||
| for line in f: | ||
| yield line.strip() | ||
| class TransformerAdaptationDataset(AdaptationDataset): | ||
| def __init__( | ||
| self, | ||
| batch_encoding_params: Iterable[Dict[str, torch.LongTensor]], | ||
| length: Optional[int] = None, | ||
| ): | ||
| def __init__(self, | ||
| batch_encoding_params: Iterable[Dict[str, torch.LongTensor]], | ||
| length: Optional[int] = None, | ||
| offset: int = 0): | ||
| """ | ||
@@ -65,2 +113,3 @@ :param batch_encoding_params: Arguments to be passed to BatchEncoding (input_ids, attention_mask, labels) | ||
| self.batch_encoding_params = batch_encoding_params | ||
| self.offset = offset | ||
@@ -76,5 +125,9 @@ def __iter__(self) -> Iterator[Dict[str, torch.LongTensor]]: | ||
| for i, encoded_sample in enumerate(self.batch_encoding_params): | ||
| if worker_info is not None: | ||
| # fast-forward the self.offset steps in continued training | ||
| if i < self.offset: | ||
| continue | ||
| if self.world_size > 1 and worker_info is not None: | ||
| # multi-gpu DataParallel | ||
| if (i - worker_info.id) % worker_info.num_workers == 0: | ||
| if i % self.world_size == worker_info.id: | ||
| # sample modulo number of all workers match this worker rank | ||
@@ -87,3 +140,2 @@ yield encoded_sample | ||
| class AdaptationArguments(TrainingArguments): | ||
| fixed_adaptation_args = { | ||
@@ -101,24 +153,28 @@ "per_device_train_batch_size": 1, # batching is done by Objective, no two distinct batches | ||
| def __init__( | ||
| self, | ||
| stopping_strategy: StoppingStrategy, | ||
| stopping_patience: Optional[int] = 10, | ||
| also_log_converged_objectives: Optional[bool] = True, | ||
| **kwargs | ||
| ): | ||
| def __init__(self, | ||
| stopping_strategy: StoppingStrategy, | ||
| stopping_patience: Optional[int] = 10, | ||
| saving_strategy: SavingStrategy = SavingStrategy.ALL_OBJECTIVES, | ||
| also_log_converged_objectives: Optional[bool] = True, | ||
| save_peft_base_model: bool = False, | ||
| **kwargs): | ||
| """ | ||
| Adds Adaptor-specific arguments to standard HF's TrainingArguments | ||
| :param stopping_strategy: A strategy to decide whether to stop training, based on the states of all objectives | ||
| :param stopping_patience: How many global steps to wait before stopping the training | ||
| :param saving_strategy: A strategy to choose the objectives for which we persist the models in checkpoints. | ||
| :param also_log_converged_objectives: Whether to perform evaluations also for already stopped objectives | ||
| :param save_peft_base_model: Whether to also persist the base model when training some objective(s) with PEFT. | ||
| """ | ||
| # novel arguments, w.r.t. original TrainingArguments | ||
| self.stopping_strategy = stopping_strategy | ||
| self.stopping_patience = stopping_patience | ||
| self.saving_strategy = saving_strategy | ||
| self.log_converged_objectives = also_log_converged_objectives | ||
| self.save_peft_base_model = save_peft_base_model | ||
| # adjustments of the defaults expected by Scheduler | ||
| unexpected_adjusted_args = [ | ||
| arg for arg in kwargs.keys() if arg in self.fixed_adaptation_args.keys() | ||
| ] | ||
| unexpected_adjusted_args = [arg for arg in kwargs.keys() if arg in self.fixed_adaptation_args.keys()] | ||
| if unexpected_adjusted_args: | ||
| raise ValueError( | ||
| "You should not set these TrainingArgs for Adaptation: %s" | ||
| % unexpected_adjusted_args | ||
| ) | ||
| raise ValueError("You should not set these TrainingArgs for Adaptation: %s" % unexpected_adjusted_args) | ||
@@ -125,0 +181,0 @@ # set default values to fixed args |
+5
-2
| Metadata-Version: 2.1 | ||
| Name: adaptor | ||
| Version: 0.2.4 | ||
| Version: 0.2.5 | ||
| Summary: Adaptor: Objective-centric Adaptation Framework for Language Models. | ||
@@ -14,3 +14,3 @@ Home-page: https://github.com/gaussalgo/adaptor | ||
| Classifier: Programming Language :: Python :: 3.8 | ||
| Requires-Python: >=3.8 | ||
| Requires-Python: >=3.7 | ||
| Description-Content-Type: text/markdown | ||
@@ -22,2 +22,5 @@ License-File: LICENSE | ||
| Requires-Dist: accelerate>=0.20.1 | ||
| Requires-Dist: peft<0.13.0,>=0.10.0 | ||
| Requires-Dist: prefetch-generator>=1.0.3 | ||
| Requires-Dist: numpy<1.24 | ||
| Provides-Extra: generative | ||
@@ -24,0 +27,0 @@ Requires-Dist: sacrebleu; extra == "generative" |
+7
-4
@@ -12,3 +12,3 @@ #!/usr/bin/env python | ||
| name="adaptor", | ||
| version='0.2.4', | ||
| version='0.2.5', | ||
| description="Adaptor: Objective-centric Adaptation Framework for Language Models.", | ||
@@ -27,5 +27,4 @@ long_description_content_type="text/markdown", | ||
| url="https://github.com/gaussalgo/adaptor", | ||
| python_requires=">=3.8", | ||
| python_requires=">=3.7", | ||
| license="MIT", | ||
| license_files=["LICENSE"], | ||
| packages=find_packages(include=["adaptor", "adaptor.*"]), | ||
@@ -38,3 +37,6 @@ include_package_data=True, | ||
| "sentencepiece", | ||
| "accelerate>=0.20.1" | ||
| "accelerate>=0.20.1", | ||
| "peft>=0.10.0,<0.13.0", | ||
| "prefetch-generator>=1.0.3", | ||
| "numpy<1.24" # constrained by integration of a prism metric, can be removed, once prism is deprecated | ||
| ], | ||
@@ -51,2 +53,3 @@ test_require=[ | ||
| "protobuf<=3.20.1", | ||
| # "omegaconf>=2.2" # previous versions are incompatible with pip<25 for unsupported deps syntax ('>=5.1.*') | ||
| ], | ||
@@ -53,0 +56,0 @@ "examples": [ |
| include LICENSE | ||
| include README.md |
Alert delta unavailable
Currently unable to show alert delta for PyPI packages.
230604
10.81%3211
11.18%33
-2.94%