feat: Modify the loading method for SigLIP and Qwen2VL_ViT
Hyggge committed Jan 21, 2025
1 parent c4b261c commit fd0aff8
Showing 3 changed files with 54 additions and 15 deletions.
valley_eagle/model/multimodal_encoder/siglip_encoder.py (8 changes: 4 additions & 4 deletions)
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from ...util.vision_encoder_config import siglip_config
+from ...util.vision_encoder_config import siglip_config, siglip_processor_config
 
 
 class SigLipVisionTower(nn.Module):
@@ -20,13 +20,13 @@ def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"
         else:
             from transformers import SiglipVisionConfig, SiglipVisionModel
 
-            self.cfg_only = SiglipVisionConfig.from_pretrained(self.image_tower_name, cache_dir=self.cache_dir)
-            self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name) # dummy-load
+            self.cfg_only = siglip_config
+            self.vision_tower = SiglipVisionModel._from_config(siglip_config) # dummy-load
 
     def load_model(self):
         from transformers import SiglipImageProcessor, SiglipVisionModel
 
-        self.image_processor = SiglipImageProcessor.from_pretrained(self.image_tower_name)
+        self.image_processor = SiglipImageProcessor.from_dict(siglip_processor_config)
         self.vision_tower = SiglipVisionModel._from_config(siglip_config)
         self.vision_tower.requires_grad_(False)
         self.image_processor.crop_size = self.image_processor.size["height"]
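Note on the pattern above: from_pretrained resolves a checkpoint (downloading it if absent) before building the module, while the transformers helper _from_config only instantiates the architecture with freshly initialized weights; the in-code "dummy-load" comment suggests the real SigLIP weights are filled in later from the parent Valley checkpoint's state dict. A minimal sketch of the two paths; the checkpoint id below is illustrative, not taken from this repo:

    from transformers import SiglipVisionConfig, SiglipVisionModel

    # Old path: fetch config and weights from the Hub (network and disk I/O).
    # tower = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")

    # New path: build the architecture from a local config object only.
    # Weights are randomly initialized here and are expected to be overwritten
    # when the full Valley checkpoint is loaded.
    config = SiglipVisionConfig()  # stand-in for the repo's siglip_config
    tower = SiglipVisionModel._from_config(config)
    tower.requires_grad_(False)  # frozen, matching load_model() above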
valley_eagle/util/vision_encoder_config.py (45 changes: 45 additions & 0 deletions)
@@ -36,3 +36,48 @@
         "_attn_implementation_internal": "flash_attention_2"
     }
 )
+
+
+siglip_processor_config = {
+    "do_normalize": True,
+    "do_rescale": True,
+    "do_resize": True,
+    "image_mean": [
+        0.5,
+        0.5,
+        0.5
+    ],
+    "image_processor_type": "SiglipImageProcessor",
+    "image_std": [
+        0.5,
+        0.5,
+        0.5
+    ],
+    "processor_class": "SiglipProcessor",
+    "resample": 3,
+    "rescale_factor": 0.00392156862745098,
+    "size": {
+        "height": 384,
+        "width": 384
+    }
+}
+
+qwen2vl_processor_config = {
+    "min_pixels": 3136,
+    "max_pixels": 12845056,
+    "patch_size": 14,
+    "temporal_patch_size": 2,
+    "merge_size": 2,
+    "image_mean": [
+        0.48145466,
+        0.4578275,
+        0.40821073
+    ],
+    "image_std": [
+        0.26862954,
+        0.26130258,
+        0.27577711
+    ],
+    "image_processor_type": "Qwen2VLImageProcessor",
+    "processor_class": "Qwen2VLProcessor"
+}
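These dicts inline what would normally live in each tower's preprocessor_config.json on the Hub, so both processors can be built without network access. transformers' ImageProcessingMixin.from_dict also accepts keyword overrides, which is how valley_eagle_chat.py caps max_pixels. A short sketch of how the two dicts are consumed, mirroring the imports in the diff below:

    from transformers import Qwen2VLImageProcessor, SiglipImageProcessor

    from valley_eagle.util.vision_encoder_config import (
        qwen2vl_processor_config,
        siglip_processor_config,
    )

    # Build both processors from the in-repo dicts instead of Hub downloads.
    siglip_processor = SiglipImageProcessor.from_dict(siglip_processor_config)

    # Keyword arguments passed to from_dict override the dict's entries, so this
    # replaces the stored max_pixels of 12845056 with 1280 * 28 * 28 = 1003520.
    qwen2vl_processor = Qwen2VLImageProcessor.from_dict(
        qwen2vl_processor_config, max_pixels=1280 * 28 * 28
    )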
valley_eagle_chat.py (16 changes: 5 additions & 11 deletions)
@@ -20,14 +20,15 @@
 from typing import Dict, List, Union
 from PIL import Image
 from qwen_vl_utils import fetch_image
-from transformers import AutoTokenizer, AutoConfig, AutoProcessor, SiglipImageProcessor
+from transformers import AutoTokenizer, SiglipImageProcessor, Qwen2VLImageProcessor
 from transformers import set_seed
 
 from valley_eagle import conversation as conversation_lib
 from valley_eagle.valley_utils import disable_torch_init
 from valley_eagle.model.language_model.valley_qwen2 import ValleyQwen2ForCausalLM
 from valley_eagle.util.data_util import dynamic_preprocess, preprocess
 from valley_eagle.util.mm_utils import process_anyres_image
+from valley_eagle.util.vision_encoder_config import siglip_processor_config, qwen2vl_processor_config
 
 logging.basicConfig(level=logging.INFO)
 
@@ -105,8 +106,8 @@ def __init__(
 
         # Load image preprocessor
         self.black_img = black_img
-        self.image_processor = SiglipImageProcessor.from_pretrained(self.model.config.mm_vision_tower)
-        self.qwen2vl_processor = AutoProcessor.from_pretrained(self.model.config.eagle_vision_tower, max_pixels=1280 * 28 * 28)
+        self.image_processor = SiglipImageProcessor.from_dict(siglip_processor_config)
+        self.qwen2vl_processor = Qwen2VLImageProcessor.from_dict(qwen2vl_processor_config, max_pixels=1280 * 28 * 28)
         self.image_processor.crop_size = self.image_processor.size["height"]
 
     def preprocess_images(self, image_binary_list) -> torch.FloatTensor:
@@ -184,14 +185,7 @@ def __call__(self, request):
         for image_file in images_pil:
             image = fetch_image({"image": image_file})
             image_list.append(image)
-        messages_qwen.append({"role": "user", "content": [{"type": "text", "text": text}]})
-        messages_qwen.append({"role": "assistant", "content": [{"type": "text", "text": ""}]})
-        text = self.qwen2vl_processor.apply_chat_template(messages_qwen[:-1], tokenize=False, add_generation_prompt=True)
-        text_segs = re.split("<image>", text)
-        text = "<|vision_start|><|image_pad|><|vision_end|>".join(text_segs[: len(image_list) + 1]) + "".join(
-            text_segs[len(image_list) + 1 :]
-        )
-        data_dict_qwen2vl = self.qwen2vl_processor(text=[text], images=image_list, padding=True, return_tensors="pt")
+        data_dict_qwen2vl = self.qwen2vl_processor(image_list, return_tensors="pt")
 
         # process messages, get tensors which will be input to model
         source = preprocess_multimodal(messages, img_length, self.model.config)
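The __call__ change above drops the text half of the Qwen2-VL preprocessing: instead of rendering messages_qwen through apply_chat_template, splicing in <|vision_start|><|image_pad|><|vision_end|> markers, and running the combined processor, the new code feeds the PIL images straight to Qwen2VLImageProcessor. A minimal sketch of the image-only call, with a dummy image standing in for image_list:

    from PIL import Image
    from transformers import Qwen2VLImageProcessor

    # Default-constructed here for illustration; the repo builds it from
    # qwen2vl_processor_config instead.
    processor = Qwen2VLImageProcessor()
    image_list = [Image.new("RGB", (448, 448))]  # dummy input

    # Image-only preprocessing: no tokenizer or chat template involved. The
    # returned BatchFeature typically carries "pixel_values" (flattened vision
    # patches) and "image_grid_thw" (per-image grid sizes) for the ViT.
    data_dict_qwen2vl = processor(image_list, return_tensors="pt")
    print(data_dict_qwen2vl["pixel_values"].shape)
    print(data_dict_qwen2vl["image_grid_thw"])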
