Transformers-Re
A regular-expression constraint for transformers language models. With this module, you can force an LLM's output to follow your regex; the project also implements regex-based extraction over tokens and tensors.
Installation
pip install transformers-re
RegexLogitsProcessor
A logits processor for transformers that constrains generation with a regex.
__init__:
- tokenizer: the transformers tokenizer to use.
- prompt: the prompt fed to the model.
- pattern: the regex pattern to enforce.
- num_proc: the number of processes used for regex matching.
- fail_strategy: default 'eos'. The strategy used when no token is allowed. It can be:
- List[int]: token ids to emit when no token is allowed.
- 'eos': automatically use tokenizer.eos_token_id.
- debug: default False. Controls whether debug information is printed.
- guess_token: default False. Try to guess the next token to speed up generation; a wrong guess slows generation down.
Attributes:
- generated_text (str): cached text; generation can be resumed from it to recover from a regex-mismatch error.
- match (regex.Match): the Match object of the generated text against the pattern.
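For example, both attributes can be read from the processor after a generate call; a minimal sketch (setup of tokenizer, model, prompt, pattern, and input_ids is assumed, as in the full example below):
with RegexLogitsProcessor(tokenizer, prompt, pattern) as regex_logits_processor:
    outputs = model.generate(input_ids, logits_processor=[regex_logits_processor])
    # inspect what has been produced so far and how it matches the pattern
    print(regex_logits_processor.generated_text)
    print(regex_logits_processor.match)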
Usage example:
>>> with RegexLogitsProcessor(tokenizer, prompt, pattern) as regex_logits_processor:
>>> model.generate(logits_processor=[regex_logits_processor])
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers_re import RegexLogitsProcessor

if __name__ == "__main__":
    # Qwen checkpoints require trust_remote_code=True
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-14B-Chat-Int4', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen-14B-Chat-Int4', trust_remote_code=True).eval().to('cuda')
    prompt = "<|im_start|>user\n请帮我写一首表达思乡之情的诗<|im_end|>\n<|im_start|>assistant\n"
    pattern = r"一[\u4e00-\u9fa5]{4},键[\u4e00-\u9fa5]{4}。三[\u4e00-\u9fa5]{4},连[\u4e00-\u9fa5]{4}。"
    with RegexLogitsProcessor(tokenizer, prompt, pattern, num_proc=16, debug=True,
                              fail_strategy=tokenizer.encode("<|im_end|><|endoftext|>")) as regex_logits_processor:
        input_ids = tokenizer(prompt, return_tensors="pt").to('cuda')["input_ids"]
        outputs = model.generate(input_ids, max_new_tokens=40, logits_processor=[regex_logits_processor])
        print(tokenizer.decode(outputs[0]))
RegexPrefixProcessor
A regex constraint for transformers implemented as a prefix_allowed_tokens_fn.
__init__:
- tokenizer: the transformers tokenizer to use.
- prompt: the prompt fed to the model.
- pattern: the regex pattern to enforce.
- num_proc: the number of processes used for regex matching.
- fail_strategy: default 'eos'. The strategy used when no token is allowed. It can be:
- List[int]: token ids to emit when no token is allowed.
- 'eos': automatically use tokenizer.eos_token_id.
- debug: default False. Controls whether debug information is printed.
Attributes:
- generated_text (str): cached text; generation can be resumed from it to recover from a regex-mismatch error.
- match (regex.Match): the Match object of the generated text against the pattern.
Usage example:
>>> with RegexPrefixProcessor(tokenizer, prompt, pattern) as regex_prefix_processor:
>>> model.generate(prefix_allowed_tokens_fn=regex_prefix_processor)
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers_re import RegexPrefixProcessor

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained('YeungNLP/firefly-1b4')
    model = AutoModelForCausalLM.from_pretrained('YeungNLP/firefly-1b4').eval().to('cuda')
    prompt = "<s>请帮我写一首表达思乡之情的诗</s></s>"
    pattern = r"一[\u4e00-\u9fa5]{4},键[\u4e00-\u9fa5]{4}。三[\u4e00-\u9fa5]{4},连[\u4e00-\u9fa5]{4}。"
    with RegexPrefixProcessor(tokenizer, prompt, pattern, num_proc=16) as regex_prefix_processor:
        input_ids = tokenizer(prompt, return_tensors="pt").to('cuda')["input_ids"]
        outputs = model.generate(input_ids, max_new_tokens=20, prefix_allowed_tokens_fn=regex_prefix_processor)
        print(tokenizer.decode(outputs[0]))
TokenizedPattern
A regex pattern compiled together with a tokenizer, so that matches can be applied to tokens or tensors.
__init__:
- tokenizer: the tokenizer to use.
- pattern: a regex pattern to match strings.
- strategy: because a string match may cut some tokens in half, this decides what to do when such truncation happens:
- 'expand': expand the match span to the minimum token span that covers the string span.
- 'shrink': shrink the match span to the maximum token span covered by the string span.
- 'error': do nothing and raise an error.
- token_mapping_func: a function f(tokenizer, str) -> List[Span]. The list length should equal the number of tokens, with each span giving the character span of the corresponding token in the string (see the sketch below).
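For example, a custom token_mapping_func could be built from a fast tokenizer's offset mapping; a minimal sketch (the function name is hypothetical, and a Span is assumed to be a (start, end) character tuple):
def offset_token_mapping(tokenizer, text):
    # Requires a "fast" tokenizer; returns one (start, end) character span per token.
    encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    return list(encoding["offset_mapping"])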
Methods:
- match(self, string, pos=None, endpos=None) -> TokenizedMatch
- search(self, string, pos=None, endpos=None) -> TokenizedMatch
TokenizedMatch
The regex match result corresponding to a TokenizedPattern and a string.
Methods:
- span(self, index=0, of_token=False)
- start(self, index=0, of_token=False)
- end(self, index=0, of_token=False)
- group(self, index=0, of_token=False)
- mask(self, index=0)
- masked_select(self, tensor, index=0, dim=-1)
Usage example:
s = "This is an example text"
token_ids = tokenizer(s, return_tensors="pt").input_ids
tokens = [tokenizer.decode(t) for t in tokenizer(s).input_ids]
print(token_ids, tokens, sep="\n")
"""
tensor([[3180, 579, 593, 3392, 2895]])
['This', ' is', ' an', ' example', ' text']
"""
from transformers_re import TokenizedPattern
a, b = TokenizedPattern(tokenizer, "ample(.*)", "expand").search(s).span()
print(s[a:b])  # 'expand' widens the match to full tokens: " example text"
a, b = TokenizedPattern(tokenizer, "ample(.*)", "shrink").search(s).span()
print(s[a:b])  # 'shrink' keeps only tokens fully inside the match: " text"
a, b = TokenizedPattern(tokenizer, "ample(.*)", "expand").search(s).span(of_token=True)
print(token_ids[0, a:b])  # token ids covered by the expanded match
a, b = TokenizedPattern(tokenizer, "ample(.*)", "expand").search(s).span(index=1, of_token=True)
print(token_ids[0, a:b])  # token ids covered by capture group 1
mask = TokenizedPattern(tokenizer, "ample(.*)", "expand").search(s).mask()
print(mask)  # boolean mask over the token sequence, True at matched positions
import torch
torch.manual_seed(42)
x = torch.randint(0, 10, (5, 5))
"""
tensor([[2, 7, 6, 4, 6],
[5, 0, 4, 0, 3],
[8, 4, 0, 4, 1],
[2, 5, 5, 7, 6],
[9, 6, 3, 1, 9]])
"""
print(TokenizedPattern(tokenizer, "ample(.*)", "expand").search(s).masked_select(x, dim=0))
"""
tensor([[2, 5, 5, 7, 6],
[9, 6, 3, 1, 9]])
"""
print(TokenizedPattern(tokenizer, "ample(.*)", "expand").search(s).masked_select(x, dim=1))
"""
tensor([[4, 6],
[0, 3],
[4, 1],
[7, 6],
[1, 9]])
"""
Limitation
The main bottleneck of regex-constrained generation is the regex matching itself. Although this project uses multiprocessing to accelerate the procedure, regex matching still accounts for most of the generation time.
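Conceptually, each generation step has to re-check every candidate token against the pattern; a simplified sketch of that inner loop (not the project's actual code; the function and its arguments are illustrative, using the regex module's partial matching):
import regex

def allowed_token_ids(tokenizer, compiled_pattern, generated_text):
    # For every candidate token, run a partial match over the whole text generated
    # so far; this per-step, per-token matching is what dominates the runtime.
    allowed = []
    for token_id in range(tokenizer.vocab_size):
        candidate = generated_text + tokenizer.decode([token_id])
        # partial=True accepts strings that could still grow into a full match
        if compiled_pattern.fullmatch(candidate, partial=True):
            allowed.append(token_id)
    return allowed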
The cost could be reduced further if incremental regex matching or a GPU regex engine were implemented in this project.