
Why not timesteps[final_timestep] if you want to decode x_0? #112

@jsrdcht

There appears to be a discrepancy between the paper and the code. The paper suggests predicting the sample at final_timestep from the latent at mid_timestep, but the code uses timesteps[mid_timestep] in both places: the UNet call and the noise_scheduler.step call that produces pred_original_sample. I expected the following:

latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timesteps[mid_timestep])
noise_pred = self.unet(
    latent_model_input,
    timesteps[mid_timestep],
    prompt_embeds,
    added_cond_kwargs=unet_added_conditions,
).sample
# Proposed change: pass timesteps[final_timestep] (instead of timesteps[mid_timestep]) to step()
pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[final_timestep], latents).pred_original_sample.to(self.weight_dtype)
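For context on what the timestep argument to step() controls: assuming an epsilon-prediction DDPM/DDIM-style scheduler (diffusers' DDPMScheduler and DDIMScheduler read alphas_cumprod at the timestep they are given), pred_original_sample is computed as (x_t - sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_bar_t). The minimal sketch below reimplements that formula on plain lists; the helper names and the two alpha_bar values are hypothetical, chosen only to show that the same latent and noise prediction give different x_0 estimates depending on which timestep's alpha_bar is used.

```python
import math


def add_noise(x0, eps, alpha_bar_t):
    """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = math.sqrt(alpha_bar_t)
    s = math.sqrt(1.0 - alpha_bar_t)
    return [a * x + s * e for x, e in zip(x0, eps)]


def predict_x0(x_t, eps_pred, alpha_bar_t):
    """Epsilon-parameterized estimate: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)."""
    a = math.sqrt(alpha_bar_t)
    s = math.sqrt(1.0 - alpha_bar_t)
    return [(xt - s * e) / a for xt, e in zip(x_t, eps_pred)]


# Hypothetical noise levels: alpha_bar at a mid timestep vs. at the final (least noisy) timestep.
alpha_bar_mid = 0.5
alpha_bar_final = 0.99

x0 = [0.3, -0.7, 1.2]
eps = [0.5, 0.1, -0.4]
x_mid = add_noise(x0, eps, alpha_bar_mid)  # latent at the mid noise level

# Same latent and same (here: exact) noise prediction; only the alpha_bar passed in differs.
x0_mid = predict_x0(x_mid, eps, alpha_bar_mid)
x0_final = predict_x0(x_mid, eps, alpha_bar_final)

print([round(v, 6) for v in x0_mid])  # → [0.3, -0.7, 1.2]
print([round(v, 6) for v in x0_final])
```

In this toy setup, where x_mid was built with alpha_bar_mid, only the first estimate recovers x0 exactly.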

Code in this repo:

for i, t in enumerate(timesteps[:mid_timestep]):
    with torch.no_grad():
        latent_model_input = latents
        latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
        noise_pred = self.unet(
            latent_model_input,
            t,
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample
        latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

latent_model_input = latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timesteps[mid_timestep])
noise_pred = self.unet(
    latent_model_input,
    timesteps[mid_timestep],
    prompt_embeds,
    added_cond_kwargs=unet_added_conditions,
).sample
pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep], latents).pred_original_sample.to(self.weight_dtype)

pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
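To see how far apart the two candidate timesteps are, the sketch below builds an inference timestep schedule and reports alpha_bar at each index. It assumes a Stable Diffusion style "scaled_linear" beta schedule (beta_start=0.00085, beta_end=0.012, 1000 training steps) and diffusers-style "leading" spacing with steps_offset=1; the repo's actual scheduler config and its mid_timestep / final_timestep values are not shown in the issue, so the helper names and the indices 25 and 49 are hypothetical.

```python
import math


def scaled_linear_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s), with betas spaced linearly in sqrt-space."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    alphas_cumprod, prod = [], 1.0
    for i in range(num_train_timesteps):
        beta = (lo + (hi - lo) * i / (num_train_timesteps - 1)) ** 2
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def leading_timesteps(num_inference_steps, num_train_timesteps=1000, steps_offset=1):
    """Descending inference timesteps with 'leading' spacing (assumed config)."""
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio + steps_offset for i in reversed(range(num_inference_steps))]


alphas_cumprod = scaled_linear_alphas_cumprod()
timesteps = leading_timesteps(50)

# Hypothetical indices into the inference schedule.
mid_timestep, final_timestep = 25, 49

for name, idx in [("mid", mid_timestep), ("final", final_timestep)]:
    t = timesteps[idx]
    print(name, t, round(alphas_cumprod[t], 4))
```

With 50 inference steps this gives timesteps from 981 down to 1, so under these assumptions timesteps[final_timestep] is t=1 (alpha_bar close to 1), while the latent produced by the loop sits at timesteps[mid_timestep] = 481.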
