OmniGen


1. Usage

1.1 Interface

To handle complex tasks, image generation models have become increasingly sophisticated, and their workflows increasingly cumbersome. Existing image generation models like SD and Flux require loading many additional network modules (such as ControlNet, IP-Adapter, Reference-Net) and extra preprocessing steps (e.g., face detection, pose detection, image cropping) to generate a satisfactory image. This complex workflow is not user-friendly. We believe that future image generation models should be simpler, generating various images directly from instructions, similar to how GPT works in language generation.

Therefore, we propose OmniGen, a model capable of handling various image generation tasks within a single framework. The goal of OmniGen is to complete various image generation tasks without relying on any additional components or image preprocessing steps. OmniGen supports tasks including text-to-image generation, image editing, subject-driven image generation, and classical vision tasks, among others. More capabilities can be found in our examples. We provide inference code so you can explore more unknown functionalities yourself.

Install

git clone https://github.com/staoxiao/OmniGen.git
cd OmniGen
pip install -e .

Generate Images

You can use the following code to generate images:

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

# Text to Image
images = pipe(
    prompt="A curly-haired man in a red shirt is drinking tea.",
    height=1024,
    width=1024,
    guidance_scale=2.5,
    seed=0,
)
images[0].save("example_t2i.png")  # save output PIL Image

# Multi-modal to Image
# In the prompt, we use a placeholder to represent each image. The placeholder should be in the format <img><|image_*|></img>.
# You can add multiple images in input_images. Please ensure that each image has its own placeholder. For example, for the list
# input_images=[img1_path, img2_path], the prompt needs two placeholders: <img><|image_1|></img> and <img><|image_2|></img>.
images = pipe(
    prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
    input_images=["./imgs/test_cases/two_man.jpg"],
    height=1024,
    width=1024,
    separate_cfg_infer=False,  # if OOM, you can set separate_cfg_infer=True
    guidance_scale=2.5,
    img_guidance_scale=1.6,
)
images[0].save("example_ti2i.png")  # save output PIL image

Some important arguments:

  • guidance_scale: The strength of the guidance. Based on our experience, it is usually best to set it between 2 and 3. The higher the value, the more similar the generated image will be to the prompt. If the image appears oversaturated, please reduce the scale.

  • height and width: The height and width of the generated image. The default is 1024x1024. OmniGen supports any size, but both numbers must be divisible by 16.

  • num_inference_steps: The number of steps to take in the diffusion process. The higher the value, the more detailed the generated image will be.

  • separate_cfg_infer: Whether to use a separate inference process for CFG guidance. If set to True, memory cost will be lower but generation will be slower. Default is False. (See the sketch after this list.)

  • use_kv_cache: Whether to use key-value cache. Default is True.

  • seed: The seed for the random number generator.
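For illustration, here is a minimal sketch of how these arguments fit together in a single call; the prompt and output file name are placeholders, and the values simply mirror the defaults and tips above.

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

# Example values only: tune guidance_scale within 2-3, and lower the
# resolution or set separate_cfg_infer=True if you run out of memory.
images = pipe(
    prompt="A watercolor painting of a lighthouse at dawn.",  # placeholder prompt
    height=768,                # must be divisible by 16
    width=768,                 # must be divisible by 16
    num_inference_steps=50,    # more steps -> more detail, but slower
    guidance_scale=2.5,        # 2-3 is usually a good range
    separate_cfg_infer=True,   # lower memory, slower generation
    use_kv_cache=True,         # default; key-value cache speeds up inference
    seed=42,                   # fix for reproducible outputs
)
images[0].save("example_args.png")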

For more examples, please refer to inference.ipynb.

Input data

OmniGen can accept multi-modal input data. Specifically, you should pass two arguments: prompt and input_images. For text-to-image generation, you can pass a single string as prompt, or a list of strings to generate multiple images.

For multi-modal to image generation, you should pass a string as prompt, and a list of image paths as input_images. The placeholder in the prompt should be in the format of <img><|image_*|></img>. For example, if you want to generate an image with a person holding a bouquet of flowers, you can pass the following prompt:

prompt = "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."
input_images = ["./imgs/test_cases/liuyifei.png"]

The placeholder <|image_1|> will be replaced by the image at input_images[0], i.e., ./imgs/test_cases/liuyifei.png.

If you want to generate multiple images, you can pass a list of prompts and a matching list of image-path lists (use an empty list for prompts without images). For example:

prompt = ["A woman holds a bouquet of flowers and faces the camera.", "A woman holds a bouquet of flowers and faces the camera. Thw woman is <img><|image_1|></img>."]
input_images = [[], ["./imgs/test_cases/liuyifei.png"]]
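As a rough sketch of how such a batched call might look (the call simply reuses the arguments shown earlier; one image is returned per prompt, and the file path is the same placeholder as above):

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

prompt = [
    "A woman holds a bouquet of flowers and faces the camera.",
    "A woman holds a bouquet of flowers and faces the camera. The woman is <img><|image_1|></img>.",
]
input_images = [[], ["./imgs/test_cases/liuyifei.png"]]  # one (possibly empty) image list per prompt

# Each prompt is paired with its own image list; one output image is returned per prompt.
images = pipe(
    prompt=prompt,
    input_images=input_images,
    height=1024,
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    seed=0,
)
for i, img in enumerate(images):
    img.save(f"example_batch_{i}.png")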

Gradio Demo

We have built an online demo on Hugging Face.

For a local Gradio demo, run the following command:

python app.py

Tips

  • OOM issue: If you encounter an OOM issue, you can set separate_cfg_infer=True. This reduces memory usage but increases generation latency. You can also reduce the image size, e.g., height=768, width=512.

  • Oversaturated: If the image appears oversaturated, please reduce the guidance_scale.

  • Doesn't match the prompt: If the image does not match the prompt, try increasing the guidance_scale.

  • Low-quality: A more detailed prompt will lead to better results. A larger image size (height and width) also helps.

  • Anime style: If the generated image is in an anime style, you can try adding "photo" to the prompt.

  • Edit a generated image: If you generate an image with OmniGen and then want to edit it, you cannot use the same seed. For example, if you used seed=0 to generate the image, use a different seed such as seed=1 to edit it. (See the sketch after these tips.)

  • For image editing tasks, we recommend placing the image before the editing instruction. For example, use <img><|image_1|></img> remove suit, rather than remove suit <img><|image_1|></img>.
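To make the last two tips concrete, here is a minimal sketch of editing an image generated earlier in this section (example_t2i.png comes from the text-to-image example above, which used seed=0; the editing prompt itself is only an illustration):

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

# Place the image placeholder before the editing instruction,
# and use a seed different from the one used to generate the image (seed=0 above).
images = pipe(
    prompt="<img><|image_1|></img> Change the red shirt to a blue shirt.",
    input_images=["example_t2i.png"],
    height=1024,
    width=1024,
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    seed=1,
)
images[0].save("example_edit.png")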


1.2 Usage

OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, identity-preserving generation, and image-conditioned generation.

For multi-modal to image generation, you should pass a string as prompt and a list of image paths as input_images. The placeholder in the prompt should be in the format <|image_*|> (for the first image, the placeholder is <|image_1|>; for the second image, the placeholder is <|image_2|>). For example, use an image of a woman to generate a new image: prompt = "A woman holds a bouquet of flowers and faces the camera. The woman is <|image_1|>."

Tips:

  • Oversaturated: If the image appears oversaturated, please reduce the guidance_scale.

  • Low-quality: A more detailed prompt will lead to better results.

  • Anime style: If the generated image is in an anime style, you can try adding "photo" to the prompt.

  • Edit a generated image: If you generate an image with OmniGen and then want to edit it, you cannot use the same seed. For example, if you used seed=0 to generate the image, use a different seed such as seed=1 to edit it.

  • For image editing tasks, we recommend placing the image before the editing instruction. For example, use <|image_1|> remove suit, rather than remove suit <|image_1|>.

1.3 Examples

This notebook introduces some of OmniGen's abilities, including text-to-image generation, identity-preserving generation, image-conditioned generation, and so on.

Some tips for generation:

  • For out-of-memory issues or long inference time, you can set offload_model=True or refer to ./docs/inference.md#requiremented-resources to select an appropriate setting.

  • If the inference time is too long when inputting multiple images, you can reduce max_input_image_size. For more details, please refer to ./docs/inference.md#requiremented-resources.

  • Oversaturated: If the image appears oversaturated, please reduce the guidance_scale.

  • Doesn't match the prompt: If the image does not match the prompt, try increasing the guidance_scale.

  • Low-quality: A more detailed prompt will lead to better results.

  • Anime style: If the generated image is in an anime style, you can try adding "photo" to the prompt.

  • Edit a generated image: If you generate an image with OmniGen and then want to edit it, you cannot use the same seed. For example, if you used seed=0 to generate the image, use a different seed such as seed=1 to edit it.

  • For image editing tasks, we recommend placing the image before the editing instruction. For example, use <|image_1|> remove suit, rather than remove suit <|image_1|>.

  • For image editing and ControlNet-style tasks, we recommend setting the height and width of the output image the same as the input image. For example, if you want to edit a 512x512 image, set the output height and width to 512x512. You can also set use_input_image_size_as_output to automatically match the output size to the input image. (See the sketch after these tips.)
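As a minimal sketch of the memory- and size-related options mentioned above (the input path and editing prompt are placeholders; all keyword arguments are the ones named in these tips):

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

# Editing call that keeps the output the same size as the input image
# and trades speed for lower memory usage.
images = pipe(
    prompt="<|image_1|> Remove the hat.",             # placeholder editing prompt
    input_images=["./imgs/test_cases/my_photo.png"],  # placeholder path
    use_input_image_size_as_output=True,  # output matches input resolution
    max_input_image_size=768,             # downscale large inputs to speed up inference
    offload_model=True,                   # offload weights if GPU memory is tight
    guidance_scale=2.5,
    img_guidance_scale=1.6,
    seed=1,
)
images[0].save("edited.png")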

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0' # select a gpu to run OmniGen

os.environ['HF_HUB_CACHE'] = 'path_to_save_downloaded_model'

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

  1. Text to Image

Here are some examples of text-to-image generation:

Text to Image

prompt = ["A vintage camera placed on the ground, ejecting a swirling cloud of Polaroid-style photographs into the air. The photos, showing landscapes, wildlife, and travel scenes, seem to defy gravity, floating upward in a vortex of motion. The camera emits a glowing, smoky light from within, enhancing the magical, surreal atmosphere. The dark background contrasts with the illuminated photos and camera, creating a dreamlike, nostalgic scene filled with vibrant colors and dynamic movement. Scattered photos are visible on the ground, further contributing to the idea of an explosion of captured memories.", "A curly-haired man in a red shirt is drinking tea.", ] for i in range(len(prompt)): images = pipe( prompt=prompt[i], #In fact, you also can pass the entire prompt list here, but it will take more memory cost to generate all images. height=1024, width=1024, guidance_scale=2.5, separate_cfg_infer=False, seed=0, )

images[0].save("i.png")

images[0].show()
  2. Subject-driven Generation or Identity-Preserving Generation

You can input an image containing a specific object (e.g., a human, an animal, or something else) and prompt the model to generate a new image based on the given object. Unlike previous work, OmniGen does not need to detect and crop the object using other models (e.g., face segmentation is required in InstantID and PuLID). OmniGen will automatically find the specified object in the image and generate a new image.

What's more, OmniGen can process input images containing multiple objects: the model can automatically identify objects in the image through descriptive instructions, e.g., "the right man in <|image_1|>", "the woman wearing pink clothes in <|image_2|>".

OmniGen can also extract multiple objects from multiple images to generate a new image.

from PIL import Image

max_input_image_size = 1024 # you can reduce this size to speed-up the inference and reduce the memory usage

prompt="The woman in <|image_1|> waves her hand happily in the crowd" input_images=["./imgs/test_cases/zhang.png"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.8,
seed=42 ) print("input_image: ") for img in input_images: Image.open(img).show() print("output:") images[0].show()

# You can pass multiple objects from multiple input images

prompt = "Two woman are raising fried chicken legs in a bar. A woman is <|image_1|>. Another woman is <|image_2|>." # a1 input_images = ["./imgs/test_cases/mckenna.jpg", "./imgs/test_cases/Amanda.jpg"]
images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.8, max_input_image_size=max_input_image_size, seed=168 ) print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show()

# For input images containing multiple objects, our model can automatically identify objects in the image through descriptive instructions.

# e.g., the right man in <|image_1|>, the woman wearing pink clothes in <|image_2|>.

prompt="A man in a black shirt is reading a book. The man is the right man in <|image_1|>." input_images=["./imgs/test_cases/two_man.jpg"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, seed=0, )

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show()

# You can also extract multiple objects from multiple images to generate a new image

# If you don't describe the clothes in the prompt, the model will tend to retain the clothing from the original image.

prompt="A man is sitting in the library reading a book, while a woman wearing white shirt next to him is wearing headphone. The man who is reading is the one wearing red sweater in <|image_1|>. The woman wearing headphones is the right woman wearing suit in <|image_2|>." input_images=["./imgs/test_cases/turing.png", "./imgs/test_cases/lecun.png"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.8, max_input_image_size=max_input_image_size, seed=2)

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show()

prompt="A man and a short-haired woman with a wrinkled face are standing in front of a bookshelf in a library. The man is the man in the middle of <|image_1|>, and the woman is oldest woman in <|image_2|>" input_images=["./imgs/test_cases/1.jpg", "./imgs/test_cases/2.jpg"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, max_input_image_size=max_input_image_size, seed=60)

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show()

prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>" input_images=["./imgs/test_cases/3.jpg", "./imgs/test_cases/4.jpg"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, offload_model=False, # If OOM(out of memory), you can set offload_model=True max_input_image_size=max_input_image_size, seed=66)

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show()

prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>." input_images=["./imgs/demo_cases/emma.jpeg", "./imgs/demo_cases/dress.jpg"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, offload_model=False, # If OOM(out of memory), you can set offload_model=True max_input_image_size=max_input_image_size, seed=666)

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show() 3. Image-conditional Generation For this task, a representative work is ControlNet. ControlNet requires using other detectors to detect conditions in the input image and then loading the corresponding modules for inference. Unlike ControlNet, OmniGen can complete both tasks (condition extraction and condition-based generation) within a single model and can even achieve this in one step (skipping the condition extraction step).

from PIL import Image

OmniGen can handle some classical CV tasks

prompt = "Generate the depth map for this image: <|image_1|>."

prompt = "Detect the skeleton of human in this image: <|image_1|>." input_images = ["./imgs/test_cases/control.jpg"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, separate_cfg_infer=False, # If OOM(out of memory), you can set separate_cfg_infer=True seed=0)

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show()

Generate images based on extracted condition

prompt = "Generate a new photo using the following picture and text as conditions: <|image_1|>n An elderly man wearing gold-framed glasses stands dignified in front of an elegant villa. His gray hair is neatly combed, and his hands rest in the pockets of his dark trousers. He is dressed warmly in a fitted coat over a sweater. The classic villa behind him features ivy-covered walls and large bay windows." input_images = ["./imgs/test_cases/pose.png"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, seed=0)

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show()

A simpler method: generate images based on the image condition in a single step

prompt = "Following the depth mapping of this image <|image_1|>, generate a new image: An elderly man wearing a blue hat and gold-framed glasses stands dignified in front of an elegant villa. His gray hair is neatly combed, and his hands rest in the pockets of his dark trousers. He is dressed warmly in a fitted coat over a sweater. The classic villa behind him features ivy-covered walls and large bay windows."

prompt = "Following the human pose of this image <|image_1|>, generate a new photo: An elderly man wearing a gold-framed glasses stands dignified in front of an elegant villa. His gray hair is neatly combed, and his hands rest in the pockets of his dark trousers. He is dressed warmly in a fitted coat over a sweater. The classic villa behind him features ivy-covered walls and large bay windows." input_images = ["./imgs/test_cases/control.jpg"] images = pipe( prompt=prompt, input_images=input_images,
height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, seed=0)

images[0].save("tii2i.png")

print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show() 4. Potential Reasoning Capability OmniGen demonstrates a certain level of reasoning capability. It can locate objects based on implicit descriptions rather than explicit instructions, such as identifying something drinkable. This feature could be useful in interesting scenarios, such as robotics.

from PIL import Image

prompt = "<|image_1|> What item can be used to see the current time? Please remove it." input_images = ["./imgs/test_cases/watch.jpg"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, seed=0) print(f"input prompt: {prompt}ninput image:") for img in input_images: Image.open(img).show() print("output:") images[0].show() 5. Other tasks In addition to the above capabilities, OmniGen also offers many other features, such as image editing, denoising, and inpainting. Feel free to explore OmniGen's abilities while using it.

If you come up with new tasks, you can also fine-tune OmniGen to give it new capabilities. Fine-tuning is very straightforward, and you can refer to our docs.

Have fun!

1.4 Usage Examples

Some tips for generation:

  • For out-of-memory issues or long inference time, you can set offload_model=True or refer to ./docs/inference.md#requiremented-resources to select an appropriate setting.

  • If the inference time is too long when inputting multiple images, you can reduce max_input_image_size. For more details, please refer to ./docs/inference.md#requiremented-resources.

  • Oversaturated: If the image appears oversaturated, please reduce the guidance_scale.

  • Doesn't match the prompt: If the image does not match the prompt, try increasing the guidance_scale.

  • Low-quality: A more detailed prompt will lead to better results.

  • Anime style: If the generated image is in an anime style, you can try adding "photo" to the prompt.

  • Edit a generated image: If you generate an image with OmniGen and then want to edit it, you cannot use the same seed. For example, if you used seed=0 to generate the image, use a different seed such as seed=1 to edit it.

  • For image editing tasks, we recommend placing the image before the editing instruction. For example, use <|image_1|> remove suit, rather than remove suit <|image_1|>.

  • For image editing and ControlNet-style tasks, we recommend setting the height and width of the output image the same as the input image. For example, if you want to edit a 512x512 image, set the output height and width to 512x512. You can also set use_input_image_size_as_output to automatically match the output size to the input image.

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0' # select a gpu to run OmniGen

os.environ['HF_HUB_CACHE'] = 'path_to_save_downloaded_model'

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

Text to Image

prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD." images = pipe( prompt=prompt, height=1024, width=1024, guidance_scale=3, seed=111 ) images[0].save("./imgs/demo_cases/t2i_woman_with_book.png") images[0].show() 0%| | 0/50 [00:00<?, ?it/s] 100%|██████████| 50/50 [00:31<00:00, 1.57it/s]

Image Editing: Our model can perform multiple editing commands simultaneously.

from PIL import Image prompt="<|image_1|>n Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola." input_images=["./imgs/demo_cases/t2i_woman_with_book.png"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, seed=222) # NOTE: If you want to edit an image generated by OmniGen's text-to-image, you must use a seed that is different from the one used to generate the original image. images[0].save("./imgs/demo_cases/edit.png") print("input_image: ") for img in input_images: Image.open(img).show() print("output: ") images[0].show() 100%|██████████| 50/50 [01:06<00:00, 1.34s/it] input_image:


Reasoning Ability: Our model demonstrates reasoning capabilities in response to images and commands.

from PIL import Image prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <|image_1|>" input_images=["./imgs/demo_cases/edit.png"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2, img_guidance_scale=1.6) images[0].save("./imgs/demo_cases/reasoning.png") print("input_image: ") for img in input_images: Image.open(img).show() print("output:") images[0].show() 0%| | 0/50 [00:00<?, ?it/s] 100%|██████████| 50/50 [01:31<00:00, 1.82s/it] input_image:


Human Skeleton Detection: Our model also possesses high-level CV task capabilities, enabling accurate human skeleton recognition.

from PIL import Image prompt="Detect the skeleton of human in this image: <|image_1|>" input_images=["./imgs/demo_cases/edit.png"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2, img_guidance_scale=1.6, seed=333) images[0].save("./imgs/demo_cases/skeletal.png") print("input_image: ") for img in input_images: Image.open(img).show() print("output:") images[0].show() 54%|█████▍ | 27/50 [00:50<00:41, 1.79s/it] 100%|██████████| 50/50 [01:31<00:00, 1.82s/it] input_image:


Conditional Generation: Our model is capable of visual condition image generation.

from PIL import Image prompt="Generate a new photo using the following picture and text as conditions: <|image_1|>n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him." input_images=["./imgs/demo_cases/skeletal.png"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2, img_guidance_scale=1.6, seed=0) images[0].save("./imgs/demo_cases/skeletal2img.png") print("input_image: ") for img in input_images: Image.open(img).show() print("output:") images[0].show() 62%|██████▏ | 31/50 [00:57<00:33, 1.79s/it] 100%|██████████| 50/50 [01:31<00:00, 1.83s/it] input_image:


Our model can also perform more complex visually conditioned generation, directly generating a new image that follows the pose of the person in an input image, end-to-end and without explicit skeleton recognition.

from PIL import Image prompt="Following the pose of this image <|image_1|>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him." input_images=["./imgs/demo_cases/edit.png"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2, img_guidance_scale=1.6, seed=123) images[0].save("./imgs/demo_cases/same_pose.png") print("input_image: ") for img in input_images: Image.open(img).show() print("output:") images[0].show() 26%|██▌ | 13/50 [00:25<01:06, 1.80s/it] 100%|██████████| 50/50 [01:31<00:00, 1.83s/it] input_image:


Subject-Driven Ability: Our model can identify the described subject in multi-person images and generate group images of individuals from multiple sources. This end-to-end process requires no additional recognition or segmentation, highlighting OmniGen's flexibility and versatility.

from PIL import Image prompt="A professor and a boy are reading a book together. The professor is the middle man in <|image_1|>. The boy is the boy holding a book in <|image_2|>." input_images=["./imgs/demo_cases/AI_Pioneers.jpg", "./imgs/demo_cases/same_pose.png"] images = pipe( prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.6, separate_cfg_infer=True, seed=0) images[0].save("./imgs/demo_cases/entity.png") print("input_image: ") for img in input_images: Image.open(img).show() print("output:") images[0].show() 0%| | 0/50 [00:00<?, ?it/s] 100%|██████████| 50/50 [03:31<00:00, 4.23s/it] input_image:


2. Training LoRA

Fine-tuning OmniGen can better help you handle specific image generation tasks. For example, by fine-tuning on a person's images, you can generate multiple pictures of that person while maintaining identity consistency.

A lot of previous work focused on designing new networks for specific tasks. For instance, ControlNet was proposed to handle image conditions, and IP-Adapter was constructed to maintain ID features. If you want to perform new tasks, you need to build new architectures and repeatedly debug them. Adding and adjusting extra network parameters is usually time-consuming and labor-intensive, which is neither user-friendly nor cost-efficient. With OmniGen, however, all of this becomes very simple.

By comparison, OmniGen can accept multi-modal conditional inputs and has been pre-trained on various tasks. You can fine-tune it on any task without designing specialized networks like ControlNet or IP-Adapter for each specific task.

All you need to do is prepare the data and start training. You can break the limitations of previous models, allowing OmniGen to accomplish a variety of interesting tasks, even those that have never been done before.

Installation

git clone https://github.com/VectorSpaceLab/OmniGen.git
cd OmniGen
pip install -e .

Full fine-tuning

Fine-tuning command

accelerate launch \
    --num_processes=1 \
    --use_fsdp \
    --fsdp_offload_params false \
    --fsdp_sharding_strategy SHARD_GRAD_OP \
    --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
    --fsdp_transformer_layer_cls_to_wrap Phi3DecoderLayer \
    --fsdp_state_dict_type FULL_STATE_DICT \
    --fsdp_forward_prefetch false \
    --fsdp_use_orig_params True \
    --fsdp_cpu_ram_efficient_loading false \
    --fsdp_sync_module_states True \
    train.py \
    --model_name_or_path Shitao/OmniGen-v1 \
    --json_file ./toy_data/toy_data.jsonl \
    --image_path ./toy_data/images \
    --batch_size_per_device 1 \
    --lr 2e-5 \
    --keep_raw_resolution \
    --max_image_size 1024 \
    --gradient_accumulation_steps 1 \
    --ckpt_every 100 \
    --epochs 100 \
    --log_every 1 \
    --results_dir ./results/toy_finetune

Some important arguments:

  • num_processes: number of GPUs to use for training

  • model_name_or_path: path to the pretrained model

  • json_file: path to the json file containing the training data, e.g., ./toy_data/toy_data.jsonl

  • image_path: path to the image folder, e.g., ./toy_data/images

  • batch_size_per_device: batch size per device

  • lr: learning rate

  • keep_raw_resolution: whether to keep the original resolution of the images; if not set, all images will be resized to (max_image_size, max_image_size)

  • max_image_size: max image size

  • gradient_accumulation_steps: number of steps to accumulate gradients

  • ckpt_every: save a checkpoint every this many steps

  • epochs: number of epochs

  • log_every: log every this many steps

  • results_dir: path to the results folder

The data format of json_file is as follows:

{    "instruction": str, 
    "input_images": [str, str, ...], 
    "output_images": str
}

You can see a toy example in ./toy_data/toy_data.jsonl.
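For illustration only, here is a small Python sketch that writes a training file in this format; the instruction texts and file names are hypothetical placeholders, not entries from the real toy_data.jsonl, and the placeholder syntax simply follows the inference examples above.

import json

# Hypothetical example entries following the format above.
samples = [
    {
        "instruction": "A dog sitting on a beach at sunset.",  # text-to-image: no input images
        "input_images": [],
        "output_images": "dog_beach.png",
    },
    {
        "instruction": "The person in <img><|image_1|></img> is riding a bicycle in a park.",
        "input_images": ["person_1.png"],
        "output_images": "person_bike.png",
    },
]

# Write one JSON object per line (JSONL).
with open("./toy_data/my_data.jsonl", "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")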

If an OOM (Out of Memory) issue occurs, you can try to decrease the batch_size_per_device or max_image_size. You can also try to use LoRA instead of full fine-tuning (see the sketch below).
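As a rough sketch of what a LoRA run might look like, assuming train.py exposes LoRA options such as --use_lora and --lora_rank as described in the project's fine-tuning docs (the flag names and hyperparameter values here are illustrative assumptions, not verified):

# LoRA fine-tuning on a single GPU; no FSDP flags needed for the small adapter.
accelerate launch --num_processes=1 train.py \
    --model_name_or_path Shitao/OmniGen-v1 \
    --use_lora \
    --lora_rank 8 \
    --json_file ./toy_data/toy_data.jsonl \
    --image_path ./toy_data/images \
    --batch_size_per_device 2 \
    --lr 1e-4 \
    --keep_raw_resolution \
    --max_image_size 1024 \
    --gradient_accumulation_steps 1 \
    --ckpt_every 100 \
    --epochs 100 \
    --log_every 1 \
    --results_dir ./results/toy_finetune_lora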

Inference

The checkpoints can be found at {results_dir}/checkpoints/*. You can use the following code to load a saved checkpoint:

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("checkpoint_path")  # e.g., ./results/toy_finetune/checkpoints/0000200



