
imageddpo
The equivalent of SDImg2ImgPipeline for DDPO: modifying DDPOTrainer to support image inputs in addition to text prompts.
Install from source:

```bash
git clone https://github.com/hectorastrom/imageddpo.git
cd imageddpo
pip install -e .
```

or from PyPI:

```bash
pip install imageddpo
```
Configure the trainer using DDPOConfig from the TRL library, then initialize the pipeline and the trainer:

```python
from imageddpo import ImageDDPOTrainer, I2IDDPOStableDiffusionPipeline
from trl import DDPOConfig

config = DDPOConfig(
    sample_num_steps=50,
    sample_guidance_scale=7.5,
    sample_eta=1.0,
    train_batch_size=4,
    train_use_8bit_adam=True,
    # ... other DDPOConfig parameters
)

# Initialize the pipeline
pipeline = I2IDDPOStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    # ... pipeline initialization parameters
)

# Initialize the trainer
trainer = ImageDDPOTrainer(
    config=config,
    model=pipeline.unet,
    ref_model=None,
    accelerator=None,
    prompt_fn=your_prompt_function,
    reward_fn=your_reward_function,
    noise_strength=0.2,  # ImageDDPO-specific parameter
    # ... other trainer parameters
)
```
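The `prompt_fn` and `reward_fn` above are user-supplied callbacks. A minimal sketch of what they might look like, assuming `prompt_fn` returns a `(prompt, image, metadata)` triple and `reward_fn` scores generated images (the exact signatures here are illustrative assumptions, not the package's documented API):

```python
# Hypothetical sketch of the two callbacks ImageDDPOTrainer expects.
# The signatures are assumptions: prompt_fn yields (prompt, image, metadata)
# and reward_fn scores a batch of generated images.

def your_prompt_function():
    # Return a text prompt, an input image, and arbitrary metadata.
    # The image here is a placeholder; in real use it would be a
    # PIL.Image or tensor loaded from your dataset.
    prompt = "a photo wearing gaussian glasses"
    image = None
    metadata = {"id": 0}
    return prompt, image, metadata

def your_reward_function(images, prompts, metadata):
    # Score each generated image; a constant reward keeps the sketch simple.
    # In practice this would call a vision model on the images.
    return [0.0 for _ in images], {}
```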
For a complete example, please refer to the gaussian glasses repo and website. There, you will see:

- rl/rl_trainer.py
- rl/reward.py

At a high level, the changes made to support image inputs:

- The conditioning becomes c = (text_prompt, input_image). The input_image is encoded and partially noised, so denoising no longer runs from t=1000 -> t=0 but from t=(1000 * noise_strength) -> t=0; with noise_strength=0.4 we're denoising from t=400 -> t=0.
- Img2ImgDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline) implements the image-to-image sampling, with noise_strength := s ranging over [0, 1].
- Sampling operates on (init_images, text_prompts, metadata) tuples, plus a save_state monkey patch.
- use_lora in the I2I pipeline (which inherits from DefaultDDPOStableDiffusionPipeline) uses the default settings on the UNet: r=4, lora_alpha=4, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"].

**Treat the input image as part of the environment state**

- prompt_fn was extended to return (prompt, image, metadata) instead of just text.
- The starting latent x_0 is derived from the image, not pure Gaussian noise.

**Switch from full text-to-image sampling to image-to-image sampling**

New sampling methods were introduced in the pipeline that:

- accept pre-computed latents from the VAE as input, and
- run only a suffix of the diffusion schedule, starting from a chosen noise level instead of from pure noise.

**Control the truncation via noise_strength / starting_step_ratio**

noise_strength in the trainer determines:

- the t_start at which noise is added when constructing x_t from the encoded image, and
- the number of denoising steps taken, via starting_step_ratio.

This couples "how corrupted the image is when the policy starts acting" with "how many denoising actions occur," making the MDP horizon explicit and tunable.
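The truncation arithmetic can be sketched in a few lines of plain Python (the helper name is illustrative, not part of imageddpo; 1000 is the standard Stable Diffusion training schedule length):

```python
# Sketch: map noise_strength in [0, 1] to the truncated denoising window.
def truncated_schedule(noise_strength, num_train_timesteps=1000, sample_num_steps=50):
    # Noise is added up to t_start; denoising then runs t_start -> 0.
    t_start = int(num_train_timesteps * noise_strength)
    # Only a proportional suffix of the sampler's steps is executed,
    # so noise_strength also sets the MDP horizon.
    num_actions = int(sample_num_steps * noise_strength)
    return t_start, num_actions

print(truncated_schedule(0.4))  # -> (400, 20): denoise from t=400 -> t=0 in 20 steps
```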
**Generate trajectories (x_t, x_{t-1}) + log-probs instead of just final images**

The pipeline was modified to:

- record the intermediate latents x_t along the denoising path, and
- use a scheduler_step that returns both the new latents and a per-step log-prob.

_generate_samples now returns latents[:, :-1], next_latents[:, 1:], and aligned timesteps.

**Extend the sampling-to-reward interface to operate on images**

compute_rewards now effectively receives (generated image, original prompt, original image, metadata) tuples, so downstream vision models can score the image-conditioned generations.

**Keep DDPO's RL machinery unchanged but re-wired to the image pipeline**

DDPOTrainer still consumes (timesteps, latents, next_latents, log_probs, advantages), but these now correspond to image-conditioned rollouts instead of purely text-conditioned ones.

**Support both "image + prompt" and "image-only" conditioning**

The pipeline handles:

- guidance_scale > 1 with classifier-free guidance (uncond + text embeddings), and
- guidance_scale <= 1 by running only the unconditional embedding path (no extra CFG forward pass), so the policy can optimize purely w.r.t. the image.
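The (x_t, x_{t-1}) pairing described above can be sketched with plain nested lists standing in for latent tensors (shapes only; names mirror the description, not the actual implementation):

```python
# Sketch of the trajectory alignment: T+1 stored latents yield T
# (state, next_state) action pairs, each aligned with one timestep
# and one per-step log-prob.

batch, T = 2, 5
# all_latents[b][k] is the latent after k denoising steps (k = 0..T)
all_latents = [[f"x_b{b}_step{k}" for k in range(T + 1)] for b in range(batch)]

latents = [row[:-1] for row in all_latents]       # states x_t
next_latents = [row[1:] for row in all_latents]   # results of each action, x_{t-1}
log_probs = [[0.0] * T for _ in range(batch)]     # one log-prob per step
timesteps = [list(range(T - 1, -1, -1)) for _ in range(batch)]

# Every state/next-state/log-prob/timestep sequence has length T,
# and next_latents[b][k] is latents[b][k+1]'s predecessor-aligned pair.
assert all(len(seq[0]) == T for seq in (latents, next_latents, log_probs, timesteps))
assert next_latents[0][0] == all_latents[0][1]
```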
**Adjust training step semantics to keep logging and epochs meaningful**

ImageDDPOTrainer.step was overridden to increment global_step by the number of collected samples, so WandB x-axes reflect "data processed" rather than only "optimizer steps," while still delegating the actual weight updates to the original DDPO training loop.

**Custom, device-safe _get_variance and scheduler_step**

The stock TRL _get_variance assumes alphas_cumprod lives on CPU and indexes with timestep.cpu(), which clashed with Accelerate moving the scheduler to CUDA. A custom _get_variance and scheduler_step were implemented that keep the timesteps on their compute device and move alphas_cumprod and final_alpha_cumprod onto that same device before gathering.

**Deterministic vs stochastic scheduler steps**

The scheduler step explicitly branches on:

- eta == 0, or a variance that is effectively zero: the log-prob is set to zero, because there is no stochastic action;
- eta > 0: the Gaussian log-prob of x_{t-1} is computed.

This is important for DDPO, since only stochastic steps should contribute meaningful policy log-probs.
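A minimal numerical sketch of this deterministic-vs-stochastic branch, with the Gaussian log-density averaged over elements rather than summed as the notes specify (the function name and scalar-std simplification are illustrative assumptions):

```python
import math

def step_log_prob(x_prev, mean, std, eta):
    # Deterministic step (eta == 0 or zero variance): there is no
    # stochastic action, so the policy log-prob is defined as 0.
    if eta == 0 or std <= 0:
        return 0.0
    # Stochastic step: Gaussian log-density of x_{t-1} under N(mean, std^2),
    # reduced with a mean over elements rather than a sum.
    per_elem = [
        -((xp - m) ** 2) / (2 * std ** 2) - math.log(std) - 0.5 * math.log(2 * math.pi)
        for xp, m in zip(x_prev, mean)
    ]
    return sum(per_elem) / len(per_elem)

print(step_log_prob([0.1, -0.2], [0.0, 0.0], 0.0, eta=0.0))  # -> 0.0
```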
**Log-prob reduction uses mean over pixels, not sum**

The per-step Gaussian log-prob is averaged over latent elements rather than summed, so its magnitude does not scale with the latent size.

**Strict latent contract for I2I (no text-only fallback)**

ImageDDPOTrainer always expects image-derived latents.

**CFG and unconditional path details**

With no textual guidance, the pipeline uses only the unconditional embeddings (negative_prompt_embeds) and runs the UNet only once per step, avoiding wasted compute; guidance_rescale is applied when CFG is active.

**Timesteps alignment with the truncated schedule**

The trainer takes timesteps from scheduler.timesteps after the I2I call, ensuring that timesteps.shape[1] == log_probs.shape[1] and that the timesteps correspond to the exact steps where actions were taken.

**Global step accounting for logging**

global_step is incremented by the number of samples collected per epoch (batch size × num batches × num processes) on top of the inner training increments.

All credit goes to the DDPO implementation from HuggingFace TRL (now deprecated) and the DDPO paper.
I also found Dr. Tanishq Abraham's blog to be incredibly helpful.