The tag name is always person #557

Open
3083156185 opened this issue Dec 26, 2024 · 2 comments

Comments

@3083156185

When I run video inference, I enter a prompt text, e.g. aircraft, and save the resulting video. It does detect the aircraft, but the label name is always person. I tried printing the inference result, and the mapped labels are always tensor([0]). What is the cause of this? The detection is spot on, but the tag name is always person.

python video_demo.py /root/autodl-tmp/YOLO-World/configs/pretrain/yolo_world_v2_l_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_1280ft_lvis_minival.py /root/autodl-tmp/YOLO-World/pretrained_weights/yolo_world_v2_l_obj365v1_goldg_pretrain_1280ft-9babe3f6.pth /root/autodl-tmp/YOLO-World/source_data/9401.mp4 'airplane' --out /root/autodl-tmp/YOLO-World/source_data/5.mp4
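
A note on the symptom: YOLO-World predictions carry class indices into the prompt list, so `tensor([0])` is expected when there is a single prompt. The name only goes wrong when the renderer looks that index up in a stale default class list, e.g. a COCO-style list whose index 0 is `person`. The fix further down builds the label strings straight from the prompts. A minimal sketch of that idea; the helper `names_from_prompts` is illustrative and not part of the repo:

```python
# Minimal sketch: resolve label names from the prompt list instead of the
# model's default class metadata. `pred_instances` is assumed to be the
# numpy-converted predictions, as in the fixed demo below.
def names_from_prompts(pred_instances, texts):
    return [f'{texts[class_id][0]} {score:0.2f}'
            for class_id, score in zip(pred_instances['labels'],
                                       pred_instances['scores'])]
```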

@3083156185
Author

It didn't solve the problem.

@ycyg8

ycyg8 commented Feb 25, 2025

I modified the video_demo.py file by mimicking image_demo.py and was able to label the video correctly:
```python
import argparse
import sys
import cv2
import mmcv
import torch
from mmengine.dataset import Compose
from mmdet.apis import init_detector
from mmengine.utils import track_iter_progress

from mmyolo.registry import VISUALIZERS
import os
import supervision as sv

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(thickness=1)
MASK_ANNOTATOR = sv.MaskAnnotator()

class LabelAnnotator(sv.LabelAnnotator):

    @staticmethod
    def resolve_text_background_xyxy(
        center_coordinates,
        text_wh,
        position,
    ):
        center_x, center_y = center_coordinates
        text_w, text_h = text_wh
        return center_x, center_y, center_x + text_w, center_y + text_h


LABEL_ANNOTATOR = LabelAnnotator(text_padding=4,
                                 text_scale=0.5,
                                 text_thickness=1)

def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-World video demo')
    parser.add_argument('--config',
                        default='configs/pretrain/'
                        'yolo_world_v2_x_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_1280ft_lvis_minival.py',
                        help='test config file path')
    parser.add_argument('--checkpoint',
                        default='weights/yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain_1280ft-14996a36.pth',
                        help='checkpoint file')
    parser.add_argument('--video',
                        default='demo/sample_images/car.mp4',
                        help='video file path')
    parser.add_argument(
        '--text',
        default='cat,dog,pig,car',
        help='text prompts, including categories separated by a comma '
        'or a txt file with each line as a prompt')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='device used for inference')
    parser.add_argument('--score-thr',
                        default=0.2,
                        type=float,
                        help='confidence score threshold for predictions')
    parser.add_argument('--out',
                        default='demo_outputs/cat.mp4',
                        type=str,
                        help='output video file')
    parser.add_argument('--frame-output-dir',
                        default='demo_outputs/frames',
                        type=str,
                        help='directory to save frames')
    args = parser.parse_args()
    return args

def inference_detector(model, image, texts, test_pipeline, score_thr=0.3):
    data_info = dict(img_id=0, img=image, texts=texts)
    data_info = test_pipeline(data_info)
    data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                      data_samples=[data_info['data_samples']])

    with torch.no_grad():
        output = model.test_step(data_batch)[0]
        pred_instances = output.pred_instances
        # Keep only predictions above the confidence threshold.
        pred_instances = pred_instances[pred_instances.scores.float() >
                                        score_thr]
    output.pred_instances = pred_instances

    return output

def prepare_frame_output_dir(frame_output_dir):
    """Prepare the directory used to save frames."""
    if os.path.exists(frame_output_dir):
        # Clear all files in the directory.
        for filename in os.listdir(frame_output_dir):
            file_path = os.path.join(frame_output_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    os.rmdir(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        # Create the directory if it does not exist.
        os.makedirs(frame_output_dir)

def main():
    args = parse_args()

    model = init_detector(args.config, args.checkpoint, device=args.device)

    # Feed frames (ndarrays) into the test pipeline instead of image files.
    model.cfg.test_dataloader.dataset.pipeline[
        0].type = 'mmdet.LoadImageFromNDArray'
    test_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    if args.text.endswith('.txt'):
        with open(args.text) as f:
            lines = f.readlines()
        texts = [[t.rstrip('\r\n')] for t in lines] + [[' ']]
    else:
        texts = [[t.strip()] for t in args.text.split(',')] + [[' ']]

    # Re-parameterize the text branch for the given prompts, as image_demo.py
    # does; the prints show the class metadata before and after.
    print("model.dataset_meta['classes'] before:",
          model.dataset_meta['classes'])
    model.reparameterize(texts)
    print("model.dataset_meta['classes'] after:",
          model.dataset_meta['classes'])

    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    video_reader = mmcv.VideoReader(args.video)
    video_writer = None
    if args.out:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(
            args.out, fourcc, video_reader.fps,
            (video_reader.width, video_reader.height))

    # Prepare the directory to save frames.
    prepare_frame_output_dir(args.frame_output_dir)

    frames = [frame for frame in video_reader]
    for idx, frame in enumerate(track_iter_progress(frames, file=sys.stdout)):
        result = inference_detector(model,
                                    frame,
                                    texts,
                                    test_pipeline,
                                    score_thr=args.score_thr)
        # pred_instances holds one row per detection; labels are indices
        # into the prompt list.
        pred_instances = result.pred_instances.cpu().numpy()

        if 'masks' in pred_instances:
            masks = pred_instances['masks']
        else:
            masks = None
        # Convert the predictions to a supervision Detections object and
        # build the label strings from the prompts themselves, so the shown
        # names follow the prompts rather than any default class list.
        detections = sv.Detections(xyxy=pred_instances['bboxes'],
                                   class_id=pred_instances['labels'],
                                   confidence=pred_instances['scores'],
                                   mask=masks)
        labels = [
            f'{texts[class_id][0]} {confidence:0.2f}'
            for class_id, confidence in zip(detections.class_id,
                                            detections.confidence)
        ]
        image = BOUNDING_BOX_ANNOTATOR.annotate(frame, detections)
        image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)

        # Save each annotated frame as an image.
        frame_filename = os.path.join(args.frame_output_dir,
                                      f'frame_{idx:06d}.jpg')
        cv2.imwrite(frame_filename, image)

        if args.out:
            video_writer.write(image)

    if video_writer:
        video_writer.release()


if __name__ == '__main__':
    main()

```
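
For reference, running the modified script against the original report's files would look like this; the flags come from `parse_args` above, and the paths are the reporter's local ones:

```bash
python video_demo.py \
    --config /root/autodl-tmp/YOLO-World/configs/pretrain/yolo_world_v2_l_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_1280ft_lvis_minival.py \
    --checkpoint /root/autodl-tmp/YOLO-World/pretrained_weights/yolo_world_v2_l_obj365v1_goldg_pretrain_1280ft-9babe3f6.pth \
    --video /root/autodl-tmp/YOLO-World/source_data/9401.mp4 \
    --text 'airplane' \
    --out /root/autodl-tmp/YOLO-World/source_data/5.mp4
```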
