PromptSMILES: Prompting for scaffold decoration and fragment linking in chemical language models
This library provides code that manipulates SMILES strings to enable iterative prompting of a trained chemical language model (CLM) that uses SMILES notation.
Installation
The library can be installed via pip:
pip install promptsmiles
Alternatively, clone this repository and install it locally. PromptSMILES requires RDKit.
git clone https://github.com/compsciencelab/PromptSMILES.git
cd PromptSMILES
pip install ./
Use
PromptSMILES is designed as a wrapper around sampling from a CLM that can accept a prompt (i.e., an initial string from which autoregressive token generation begins). It therefore requires two callable functions, described under "Required chemical language model functions". PromptSMILES has three main classes: DeNovo (a dummy wrapper that keeps code consistent), ScaffoldDecorator, and FragmentLinker.
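To make "prompt" concrete: autoregressive generation starts from the prompt string and appends one token at a time until an end token is produced. The sketch below is a toy illustration, not PromptSMILES code; `generate`, `toy_next_token`, and the `$` end token are hypothetical stand-ins for a real CLM's sampling loop, and characters are treated as tokens for simplicity.

```python
from typing import Callable

END = "$"  # hypothetical end-of-sequence token


def generate(prompt: str, next_token: Callable[[str], str], max_len: int = 100) -> str:
    """Extend `prompt` one token at a time until END or max_len characters."""
    seq = prompt
    while len(seq) < max_len:
        tok = next_token(seq)  # a real CLM would sample this from p(token | seq)
        if tok == END:
            break
        seq += tok
    return seq


# Toy "model": append carbons until the string is 6 characters long, then stop.
def toy_next_token(seq: str) -> str:
    return "C" if len(seq) < 6 else END


print(generate("CCN", toy_next_token))  # CCNCCC
```

The prompt is never modified, only extended, which is why the placement of an attachment point within the prompt matters.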
Scaffold Decoration
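from promptsmiles import ScaffoldDecorator, FragmentLinker

# CLM is your own model object; sampler and evaluater are the two callables
# described under "Required chemical language model functions"
SD = ScaffoldDecorator(
    scaffold="N1(*)CCN(CC1)CCCCN(*)",  # each (*) marks an attachment point to decorate
    batch_size=64,
    sample_fn=CLM.sampler,
    evaluate_fn=CLM.evaluater,
    batch_prompts=False,
    optimize_prompts=True,
    shuffle=True,
    return_all=False,
)
smiles = SD.sample(batch_size=3, return_all=True)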
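In the scaffold above, each `*` is a dummy atom marking an attachment point. Because a CLM can only extend the end of a string, one way to turn an attachment point into a prompt is to write it last and drop it, so the generated continuation bonds to the atom that carried it. The sketch below illustrates only that idea for a trailing `(*)`; `attachment_points` and `trailing_prompt` are hypothetical helpers, not PromptSMILES functions, and the library's own handling of attachment points elsewhere in the string (such as the leading `N1(*)`) is not reproduced here.

```python
def attachment_points(smiles: str) -> int:
    """Count attachment points; in SMILES a '*' is a dummy (wildcard) atom."""
    return smiles.count("*")


def trailing_prompt(scaffold: str) -> str:
    """Drop a final '(*)' so a CLM continuation bonds to the atom that carried it.

    Hypothetical helper: only handles an attachment point written as the last token.
    """
    suffix = "(*)"
    if not scaffold.endswith(suffix):
        raise ValueError("sketch only handles a trailing '(*)' attachment point")
    return scaffold[: -len(suffix)]


scaffold = "N1(*)CCN(CC1)CCCCN(*)"
print(attachment_points(scaffold))  # 2
print(trailing_prompt(scaffold))    # N1(*)CCN(CC1)CCCCN
```

A continuation such as `C(=O)C` would then give `N1(*)CCN(CC1)CCCCNC(=O)C`, placing an acetyl group on the terminal nitrogen; the remaining `*` would still need its own prompt.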
Fragment linking / scaffold hopping
FL = FragmentLinker(
    fragments=["N1(*)CCNCC1", "C1CC1(*)"],
    batch_size=64,
    sample_fn=CLM.sampler,
    evaluate_fn=CLM.evaluater,
    batch_prompts=False,
    optimize_prompts=True,
    shuffle=True,
    scan=False,
    return_all=False,
)
smiles = FL.sample(batch_size=3)
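Conceptually, linking means generating the atoms that join the fragments' attachment points. The sketch below builds such a result by hand for the two fragments above, with a fixed linker `CC` standing in for what the CLM would generate. `strip_attachment` and `link` are hypothetical helpers, not the library's method, and they rely on an assumption the code cannot check: the head fragment's `*` sits on its last written atom and the tail fragment's `*` on its first written atom (true for `C1CC1(*)` and `N1(*)CCNCC1`).

```python
def strip_attachment(fragment: str) -> str:
    """Remove a fragment's single '(*)' attachment point."""
    if fragment.count("*") != 1 or fragment.count("(*)") != 1:
        raise ValueError(f"expected exactly one '(*)' in {fragment!r}")
    return fragment.replace("(*)", "")


def link(head: str, linker: str, tail: str) -> str:
    """Join head + linker + tail by string concatenation.

    Assumes head's '*' is on its last written atom and tail's '*' is on its
    first written atom, so the linker bonds to both attachment atoms.
    """
    return strip_attachment(head) + linker + strip_attachment(tail)


print(link("C1CC1(*)", "CC", "N1(*)CCNCC1"))  # C1CC1CCN1CCNCC1
```

Reusing the ring-closure digit `1` is valid here because the cyclopropane ring is closed before the piperazine ring opens.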
Required chemical language model functions
Note the two required callables, CLM.sampler and CLM.evaluater, passed as sample_fn and evaluate_fn in the examples above. The first samples SMILES from the CLM given a prompt:
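from typing import Union

def CLM_sampler(prompt: Union[str, list[str]], batch_size: int) -> list[str]:
    """
    Input: Must have a prompt and batch_size argument.
    Output: SMILES [list]
    """
    # Placeholder: sample batch_size SMILES from your CLM, each starting from
    # prompt (or from prompt[i] when a list of prompts is passed).
    raise NotImplementedError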
Note: For a more efficient implementation, the sampler should accept a list of prompts whose length equals batch_size, and batch_prompts should be set to True in the PromptSMILES class used.
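For testing the wiring before a real model is available, a stand-in sampler only has to honour this contract: accept a single prompt or a list of batch_size prompts and return batch_size SMILES. This sketch assumes each returned SMILES is the prompt followed by its continuation. `toy_sampler` and `TOY_SUFFIXES` are hypothetical names, not part of PromptSMILES, and the random suffixes merely imitate what a trained CLM would generate.

```python
import random
from typing import Union

# Hypothetical continuations standing in for real CLM generations.
TOY_SUFFIXES = ["C", "CC", "C(=O)O", "c1ccccc1", "CN"]


def toy_sampler(prompt: Union[str, list[str]], batch_size: int) -> list[str]:
    """Mock CLM sampler: returns batch_size strings, each prompt + a random suffix."""
    if isinstance(prompt, str):
        prompts = [prompt] * batch_size  # one prompt shared by the whole batch
    else:
        if len(prompt) != batch_size:
            raise ValueError("batched prompting needs one prompt per sample")
        prompts = list(prompt)  # batched path: one prompt per sample
    return [p + random.choice(TOY_SUFFIXES) for p in prompts]


print(toy_sampler("CCN", batch_size=3))
print(toy_sampler(["CCN", "CCO"], batch_size=2))
```

Handling both input types in one function lets the same mock be used with batch_prompts set to either False or True.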
The second evaluates the negative log-likelihood (NLL) of each SMILES in a list:
def CLM_evaluater(smiles: list[str]):
    """
    Input: A list of SMILES
    Output: NLLs [list, np.array, torch.tensor] (on CPU, without gradient)
    """
    # Placeholder: compute one NLL per SMILES with your CLM.
    raise NotImplementedError
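The NLL of a sequence is the negative sum of the log-probabilities the model assigns to each token given the tokens before it. The hypothetical mock below makes that explicit with a uniform character-level "model" in which every character, plus an end-of-sequence token, has probability 1/(V+1). `toy_evaluater`, `sequence_nll`, and `VOCAB` are illustrative names, not PromptSMILES code; a real CLM would supply its own tokenizer and conditional probabilities. It returns a plain list, one of the accepted output types.

```python
import math

# Hypothetical character vocabulary for a toy uniform "model".
VOCAB = sorted(set("CNOSFclnos()=#123456789[]@+-*"))


def sequence_nll(step_probs: list[float]) -> float:
    """NLL = -sum_t log p(token_t | tokens_<t)."""
    return -sum(math.log(p) for p in step_probs)


def toy_evaluater(smiles: list[str]) -> list[float]:
    """Mock CLM evaluater: one NLL per SMILES under a uniform character model."""
    p = 1.0 / (len(VOCAB) + 1)  # +1 for the end-of-sequence token
    nlls = []
    for s in smiles:
        # Each character and the final end token cost -log(p);
        # characters outside VOCAB are not checked in this sketch.
        nlls.append(sequence_nll([p] * (len(s) + 1)))
    return nlls


print(toy_evaluater(["CC", "CCO"]))
```

Under this uniform model the NLL depends only on string length, so it is useful for checking plumbing but carries no chemical information.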