adaptor
import abc
import inspect
from typing import Optional, Union, Dict, Any, Tuple

import torch
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss
from torch.nn.functional import log_softmax, softmax
from transformers import BatchEncoding, PreTrainedModel
from transformers.utils import ModelOutput

from adaptor.lang_module import LangModule
from adaptor.objectives.objective_base import Objective


class Distillation(Objective, abc.ABC):
    """
    Model-agnostic implementation of distillation, as introduced in the DistilBERT paper: https://arxiv.org/abs/1910.01108
    This implementation will work out of the box with default parameters for models with the same prediction dimensionality
    (= number of categories, or vocab size for LMs).
    Note that objectives can occasionally produce more than one output (currently only ExtractiveQA);
    in such cases, the model prediction can be flattened in a custom wrapper.
    """

    def __init__(self, *args,
                 teacher_model: PreTrainedModel,
                 temperature: int = 1,
                 logits_ce_loss_weight: int = 1,
                 hidden_cossim_loss_weight: int = 1,
                 add_hidden_states_loss: bool = False,
                 restrict_loss_to_mask: bool = False,
                 **kwargs):
        """
        See the description of the distillation parameters in https://arxiv.org/abs/1910.01108.
        :param teacher_model: A model to distill the predictions from.
        :param temperature: Distillation intensity; the smaller, the stronger. See the DistilBERT paper for details.
        :param logits_ce_loss_weight: Relative weight of the logits' cross-entropy loss.
        :param hidden_cossim_loss_weight: Relative weight of the hidden states' loss.
        :param add_hidden_states_loss: Whether to also include the hidden states' loss.
               Note that the hidden states' loss works only when distilling from a model
               with the same hidden states' dimensionality. Defaults to False.
        :param restrict_loss_to_mask: Whether to compute the selected losses only from the attended positions
               (set to True), or from all positions (default, set to False).
        """
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.logits_ce_loss_weight = logits_ce_loss_weight
        self.hidden_cossim_loss_weight = hidden_cossim_loss_weight
        self.restrict_loss_to_mask = restrict_loss_to_mask
        self.add_hidden_states_loss = add_hidden_states_loss

        if add_hidden_states_loss:
            # in this case, we'll need the teacher to yield the hidden states in the output
            self.teacher_model.config.output_hidden_states = True

        super().__init__(*args, **kwargs)

    def register_compatible_head_model(self,
                                       lang_module: LangModule,
                                       other_objective: Optional["Objective"] = None,
                                       objective_args_for_head_config: Optional[Dict[str, Any]] = None,
                                       preloaded_module: Optional[torch.nn.Module] = None) -> torch.nn.Module:
        if self.add_hidden_states_loss:
            # if the loss is computed also from the hidden_states, we make sure they are actually requested
            if objective_args_for_head_config is not None:
                objective_args_for_head_config["output_hidden_states"] = True
            else:
                objective_args_for_head_config = {"output_hidden_states": True}

        return super().register_compatible_head_model(lang_module,
                                                       other_objective,
                                                       objective_args_for_head_config,
                                                       preloaded_module)

    def _loss_for_hidden_states(self,
                                student_hidden: Tuple[torch.FloatTensor],
                                teacher_hidden: Tuple[torch.FloatTensor],
                                attn_mask: torch.LongTensor,
                                teacher_select_method: str = "alternate") -> torch.FloatTensor:
        assert student_hidden[0].shape[-1] == teacher_hidden[0].shape[-1], \
            "If adding the loss of the hidden states, student and teacher must have embeddings of the same dimension."
        cosine_loss = CosineEmbeddingLoss(reduction="mean")

        if teacher_select_method == "alternate":
            teacher_student_ratio = len(teacher_hidden) / len(student_hidden)
            assert teacher_student_ratio >= 1.0, "The number of the teacher's hidden states (%s) " \
                                                 "must be at least the number of the student's (%s)." \
                                                 % (len(teacher_hidden), len(student_hidden))
            # select every n-th hidden state of the teacher, as in the DistilBERT implementation
            # if the former size is not a multiple of the latter, we select approximately proportional layers
            selected_teacher_hs = torch.arange(0, len(teacher_hidden), teacher_student_ratio).long()
            teacher_hidden_selected = [hidden for i, hidden in enumerate(teacher_hidden) if i in selected_teacher_hs]
        else:
            raise ValueError("Unknown teacher_select_method: %s" % teacher_select_method)

        student_hidden = torch.vstack([h.unsqueeze(0) for h in student_hidden])
        teacher_hidden_selected = torch.vstack([h.unsqueeze(0) for h in teacher_hidden_selected])

        if self.restrict_loss_to_mask:
            # compute loss only from the attended positions
            attn_mask_reshaped = attn_mask.unsqueeze(-1).unsqueeze(0).expand_as(student_hidden).bool()

            student_hidden_flat = torch.masked_select(student_hidden, attn_mask_reshaped)
            teacher_hidden_selected_flat = torch.masked_select(teacher_hidden_selected, attn_mask_reshaped)
            # we flatten the batch, to get the hidden vectors to the 2nd dimension
            student_hidden_unbatched = student_hidden_flat.reshape(-1, student_hidden.shape[-1])
            teacher_hidden_unbatched = teacher_hidden_selected_flat.reshape(-1, student_hidden.shape[-1])
        else:
            # we flatten the batch, to get the hidden vectors to the 2nd dimension
            student_hidden_unbatched = student_hidden.reshape(-1, student_hidden.shape[-1])
            teacher_hidden_unbatched = teacher_hidden_selected.reshape(-1, student_hidden.shape[-1])

        # CosineEmbeddingLoss targets of ones: maximise the similarity of corresponding student/teacher states
        similarity_or_distance_loss = torch.ones(student_hidden_unbatched.shape[0],
                                                 device=student_hidden_unbatched.device)

        return cosine_loss(student_hidden_unbatched, teacher_hidden_unbatched, similarity_or_distance_loss)

    def _hidden_states_loss(self,
                            student_outputs: ModelOutput,
                            teacher_outputs: ModelOutput,
                            attn_mask: torch.LongTensor) -> torch.FloatTensor:
        if hasattr(teacher_outputs, "hidden_states"):
            # encoder-only, or decoder-only model
            assert hasattr(student_outputs, "hidden_states"), "Student and teacher must be of the same type"

            student_hidden = student_outputs.hidden_states
            teacher_hidden = teacher_outputs.hidden_states

            loss = self._loss_for_hidden_states(student_hidden, teacher_hidden, attn_mask)
        elif hasattr(teacher_outputs, "encoder_hidden_states") and hasattr(teacher_outputs, "decoder_hidden_states"):
            # encoder-decoder: sum the losses for both the encoder and the decoder states
            assert hasattr(student_outputs, "encoder_hidden_states") \
                and hasattr(student_outputs, "decoder_hidden_states"), "Student and teacher must be of the same type"

            student_encoder_hidden = student_outputs.encoder_hidden_states
            teacher_encoder_hidden = teacher_outputs.encoder_hidden_states

            student_decoder_hidden = student_outputs.decoder_hidden_states
            teacher_decoder_hidden = teacher_outputs.decoder_hidden_states

            loss = (self._loss_for_hidden_states(student_encoder_hidden, teacher_encoder_hidden, attn_mask) +
                    self._loss_for_hidden_states(student_decoder_hidden, teacher_decoder_hidden, attn_mask))
        else:
            raise ValueError("Please initialize both the teacher and the student model with `output_hidden_states=True`")

        return loss

    def _compute_loss(self,
                      student_logits: torch.FloatTensor,
                      labels: torch.LongTensor,
                      inputs: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]] = None) -> torch.FloatTensor:
        assert inputs is not None, "Distillation loss requires model inputs to be passed"

        # output logits' loss
        ce_loss = CrossEntropyLoss()

        teacher_inputs = inspect.getfullargspec(self.teacher_model.forward).args
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**{k: v for k, v in inputs.items() if k in teacher_inputs})
        teacher_logits = teacher_outputs.logits

        if self.restrict_loss_to_mask:
            # pick only the predictions of tokens on the attended positions (i.e. ignore the others)
            attn_mask_reshaped = inputs["attention_mask"].unsqueeze(-1).expand_as(student_logits).bool()

            student_logits_flat = torch.masked_select(student_logits, attn_mask_reshaped)
            student_logits_unbatched = student_logits_flat.reshape(-1, student_logits.shape[-1])
            # we flatten the batch, to get the class scores & probabilities to the 2nd dimension
            teacher_logits_flat = torch.masked_select(teacher_logits, attn_mask_reshaped)
            teacher_logits_unbatched = teacher_logits_flat.reshape(-1, student_logits.shape[-1])
        else:
            # we flatten the batch, to get the class scores & probabilities to the 2nd dimension
            student_logits_unbatched = student_logits.flatten(end_dim=1)
            teacher_logits_unbatched = teacher_logits.flatten(end_dim=1)

        distil_loss = ce_loss(log_softmax(student_logits_unbatched / self.temperature, dim=-1),
                              softmax(teacher_logits_unbatched / self.temperature, dim=-1)) * self.temperature ** 2
        distil_loss = self.logits_ce_loss_weight * distil_loss
        # end output logits' loss

        if self.add_hidden_states_loss:
            # hidden states loss
            student_inputs = inspect.getfullargspec(self.compatible_head_model.forward).args
            student_outputs = self.compatible_head_model(**{k: v for k, v in inputs.items() if k in student_inputs})

            hidden_loss = self._hidden_states_loss(student_outputs, teacher_outputs, inputs["attention_mask"])
            hidden_loss_scaled = self.hidden_cossim_loss_weight * hidden_loss

            distil_loss = distil_loss + hidden_loss_scaled

        return distil_loss
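
Since Distillation subclasses abc.ABC, it is not used on its own: the intended pattern is to mix it into a concrete objective via multiple inheritance, so the objective keeps its data pipeline while its loss is replaced by the distillation loss above. The snippet below is only a minimal sketch of that pattern; the MaskedLanguageModeling import path, the constructor arguments (texts_or_path, batch_size) and the data path are assumptions inferred from the tests further down, not verified API.

# Minimal usage sketch (assumed API, not verified): distilling a smaller masked-LM student
# from a larger teacher. The hidden-state loss requires matching hidden dimensionality;
# both models below use 768-dimensional hidden states.
from transformers import AutoModelForMaskedLM

from adaptor.lang_module import LangModule
from adaptor.objectives.distillation import Distillation
from adaptor.objectives.MLM import MaskedLanguageModeling  # assumed module path


class DistilledMLM(Distillation, MaskedLanguageModeling):
    """MLM objective whose loss is replaced by distillation from the teacher."""
    pass


lang_module = LangModule("distilbert-base-cased")                  # student
teacher = AutoModelForMaskedLM.from_pretrained("bert-base-cased")  # teacher

objective = DistilledMLM(lang_module,
                         teacher_model=teacher,
                         texts_or_path="tests/mock_data/domain_unsup.txt",  # example data path
                         batch_size=4,
                         add_hidden_states_loss=True,
                         restrict_loss_to_mask=True)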
 Metadata-Version: 2.1
 Name: adaptor
-Version: 0.1.6
+Version: 0.2.0
 Summary: Adaptor: Objective-centric Adaptation Framework for Language Models.
@@ -5,0 +5,0 @@ Home-page: https://github.com/gaussalgo/adaptor
@@ -32,2 +32,3 @@ .gitignore
 adaptor/objectives/denoising.py
+adaptor/objectives/distillation.py
 adaptor/objectives/objective_base.py
@@ -34,0 +35,0 @@ adaptor/objectives/question_answering.py
@@ -153,3 +153,8 @@ import logging
         """
-        selected_head_model = self.trainable_models[str(inputs["oid"])]
+        try:
+            selected_head_model = self.trainable_models[str(inputs["oid"])]
+        except KeyError:
+            raise ValueError("Requesting inference with the objective having no registered head. "
+                             "If you are using `extra_eval_objectives`, "
+                             "do not forget to fill in their `share_other_objective_head`.")
         # include only correct inputs for a specific model
@@ -156,0 +161,0 @@ list_of_model_specific_inputs = inspect.getfullargspec(selected_head_model.forward).args
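
The new try/except turns a bare KeyError into an actionable message: an objective used only for evaluation has no head registered in the LangModule unless it explicitly shares one. A hedged sketch of the wiring the message hints at follows; the ParallelSchedule signature and the `extra_eval_objectives` / `share_other_objective_head` argument names are taken from the error text and the package's general interface and should be checked against the actual API.

# Hedged sketch, continuing the DistilledMLM example above: an evaluation-only objective
# that reuses the trained objective's head instead of registering its own.
from adaptor.schedules import ParallelSchedule  # assumed import path

eval_objective = MaskedLanguageModeling(lang_module,
                                        texts_or_path="tests/mock_data/domain_unsup.txt",
                                        batch_size=4,
                                        share_other_objective_head=objective)  # reuse the trained head

schedule = ParallelSchedule(objectives=[objective],
                            extra_eval_objectives=[eval_objective],
                            args=adaptation_args)  # an AdaptationArguments instance defined elsewhere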
@@ -14,13 +14,25 @@ from typing import Dict, Iterable, Optional, Union
-    def _wordpiece_token_label_alignment(self, texts: Iterable[str],
-                                         labels: Iterable[str]) -> Iterable[Dict[str, torch.LongTensor]]:
+    def _wordpiece_token_label_alignment(self,
+                                         texts: Iterable[str],
+                                         labels: Iterable[str],
+                                         label_all_tokens: bool = True,
+                                         ignore_label_idx: int = -100) -> Iterable[Dict[str, torch.LongTensor]]:
         """
         Decompose given space-separated labels and words into subword-aligned input ids and label ids,
         Performs batching and collation and return resulting encodings.
-        :param texts: Sentence-level input texts.
-        :param labels: Sentence-level input labels, aligned with input words by spaces.
         NOTE: be aware that due to the segmentation by spaces, tokenization might differ between the training
         and inference for the models using space-including tokenizers, such as sentencepiece.
         We tested this objective only with commonly-used Encoders (BERT, RoBERTa) utilizing pre-tokenized WPiece & BPE.
         For an example of expected inputs, see tests/mock_data/supervised_texts.txt
         and tests/mock_data/supervised_texts_token_labels.txt
-        :return: Aligned encodings.
+        :param texts: Sentence-level input texts.
+        :param labels: Sentence-level input labels, aligned with input words by spaces.
+        :param label_all_tokens: Whether to assign consistent label to all wordpieces of labeled tokens,
+               or only to the first wordpiece, giving `ignore_label_idx` to the following wordpieces.
+        :param ignore_label_idx: a label assigned to the wordpieces assigned no labels.
+        :return: Aligned, batched encodings.
         """
@@ -30,2 +42,22 @@ collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8)
+        # special tokens identification: general heuristic
+        ids1 = self.tokenizer("X").input_ids
+        ids2 = self.tokenizer("Y").input_ids
+        special_bos_tokens = []
+        for i in range(len(ids1)):
+            if ids1[i] == ids2[i]:
+                special_bos_tokens.append(ids1[i])
+            else:
+                break
+        special_eos_tokens = []
+        for i in range(1, len(ids1)):
+            if ids1[-i] == ids2[-i]:
+                special_eos_tokens.append(ids1[-i])
+            else:
+                break
+        special_eos_tokens = list(reversed(special_eos_tokens))
         # per-sample iteration
         for text, text_labels in zip(texts, labels):
@@ -35,14 +67,2 @@ tokens = text.split()
-            tokenizer_encodings = self.tokenizer(text, truncation=True)
-            # attention mask is lang_module-specific
-            attention_mask = tokenizer_encodings.attention_mask
-            wpiece_ids = tokenizer_encodings.input_ids
-            wordpieces = self.tokenizer.batch_decode(wpiece_ids)
-            out_label_ids = []
-            # next token lookup - avoid out-of-index, and exclude from token labels
-            tokens.append(wordpieces[-1])
-            labels.append("O")
             assert len(tokens) == len(labels), \
@@ -52,14 +72,34 @@ "A number of tokens in the first line is different than a number of labels. " \
-            # assign current label to current wordpiece until the current_token is fully iterated-over
-            current_token, current_label = tokens.pop(0), labels.pop(0)  # noqa F401: current_token unused
-            for wpiece_id, wpiece in zip(wpiece_ids, wordpieces):
-                next_token = tokens[0]
-                if next_token.startswith(wpiece):
-                    # if the next token starts with a current wordpiece, move to the next token + label
-                    current_token, current_label = tokens.pop(0), labels.pop(0)  # noqa F401: current_token unused
-                out_label_ids.append(self.labels_map[current_label])
+            tokens_ids = self.tokenizer(tokens, truncation=True, add_special_tokens=False).input_ids
-            batch_features.append({"input_ids": wpiece_ids,
-                                   "attention_mask": attention_mask,
-                                   "labels": out_label_ids})
+            wpiece_ids = special_bos_tokens.copy()
+            # labels of BoS and EoS are always "other"
+            out_label_ids = [ignore_label_idx] * len(special_bos_tokens)
+            for token_ids, label in zip(tokens_ids, labels):
+                # chain the wordpieces without the special symbols for each token
+                wpiece_ids.extend(token_ids)
+                if label_all_tokens:
+                    # label all wordpieces
+                    out_label_ids.extend([self.labels_map[label]] * len(token_ids))
+                else:
+                    # label only the first wordpiece
+                    out_label_ids.append(self.labels_map[label])
+                    # ignore the predictions over other token's wordpieces from the loss
+                    out_label_ids.extend([ignore_label_idx] * (len(token_ids) - 1))
+            out_label_ids.extend([ignore_label_idx] * len(special_eos_tokens))
+            wpiece_ids.extend(special_eos_tokens.copy())
+            assert len(out_label_ids) == len(wpiece_ids), "We found misaligned labels in sample: '%s'" % text
+            if self.tokenizer.model_max_length is None:
+                truncated_size = len(out_label_ids)
+            else:
+                truncated_size = min(self.tokenizer.model_max_length, len(out_label_ids))
+            batch_features.append({"input_ids": wpiece_ids[:truncated_size],
+                                   "attention_mask": [1] * truncated_size,
+                                   "labels": out_label_ids[:truncated_size]})
             # maybe yield a batch
@@ -66,0 +106,0 @@ if len(batch_features) == self.batch_size:
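
The new alignment replaces the old wordpiece-prefix matching with a simpler scheme: each space-separated token is tokenized without special symbols, and the resulting sequence is wrapped in the special BoS/EoS ids detected by the heuristic at the top of the method, which encodes two different one-character texts and keeps their shared prefix and suffix. A self-contained illustration of that heuristic follows; the model name is only an example.

# Standalone illustration of the BoS/EoS detection heuristic used above.
# "bert-base-cased" is just an example; any tokenizer wrapping its inputs in
# fixed special tokens should behave similarly.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

ids1 = tokenizer("X").input_ids  # e.g. [CLS], X, [SEP] for BERT-like tokenizers
ids2 = tokenizer("Y").input_ids  # e.g. [CLS], Y, [SEP]

# the shared prefix of the two encodings = special beginning-of-sequence tokens
special_bos_tokens = []
for a, b in zip(ids1, ids2):
    if a == b:
        special_bos_tokens.append(a)
    else:
        break

# the shared suffix of the two encodings = special end-of-sequence tokens
special_eos_tokens = []
for a, b in zip(reversed(ids1), reversed(ids2)):
    if a == b:
        special_eos_tokens.append(a)
    else:
        break
special_eos_tokens.reverse()

print(tokenizer.convert_ids_to_tokens(special_bos_tokens))  # ['[CLS]']
print(tokenizer.convert_ids_to_tokens(special_eos_tokens))  # ['[SEP]']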
@@ -363,3 +363,4 @@ import abc
-    def register_compatible_head_model(self, lang_module: LangModule,
+    def register_compatible_head_model(self,
+                                       lang_module: LangModule,
                                        other_objective: Optional["Objective"] = None,
@@ -366,0 +367,0 @@ objective_args_for_head_config: Optional[Dict[str, Any]] = None,
 Metadata-Version: 2.1
 Name: adaptor
-Version: 0.1.6
+Version: 0.2.0
 Summary: Adaptor: Objective-centric Adaptation Framework for Language Models.
@@ -5,0 +5,0 @@ Home-page: https://github.com/gaussalgo/adaptor
@@ -12,3 +12,3 @@ #!/usr/bin/env python
     name="adaptor",
-    version='0.1.6',
+    version='0.2.0',
     description="Adaptor: Objective-centric Adaptation Framework for Language Models.",
@@ -15,0 +15,0 @@ long_description_content_type="text/markdown",
@@ -12,7 +12,3 @@ from adaptor.lang_module import LangModule
-unsup_target_domain_texts = "mock_data/domain_unsup.txt"
-sup_target_domain_texts = "mock_data/supervised_texts.txt"
-sup_target_domain_labels = "mock_data/supervised_texts_token_labels.txt"
 def assert_module_objective_ok(lang_module: LangModule, objective: Objective, split: str = "train"):
@@ -19,0 +15,0 @@ # dataset iteration test
@@ -25,8 +25,11 @@ from adaptor.utils import AdaptationArguments, StoppingStrategy
     "translation_multi": {
-        "model": "sshleifer/tiny-mbart",
-        "test_src_lang": "en_XX",
-        "test_tgt_lang": "cs_CZ"},
-    "token_classification": "bert-base-multilingual-cased",
-    "sequence_classification": "bert-base-multilingual-cased",
-    "extractive_QA": "Unbabel/xlm-roberta-comet-small"
+        "model": "sshleifer/tiny-mbart",
+        "test_src_lang": "en_XX",
+        "test_tgt_lang": "cs_CZ"
+    },
+    "token_classification": "bert-base-cased",
+    "sequence_classification": "bert-base-cased",
+    "extractive_QA": "Unbabel/xlm-roberta-comet-small",
+    "MLM": "bert-base-cased",
+    "MLM_student": "distilbert-base-cased"
 }
@@ -33,0 +36,0 @@