一个YOLO11检测物体的Gradio APP
最近新接触Gradio,尝试着制作了一个用YOLO检测物体的APP,并对检测到的物体进行分类计数、绘制热图。该APP部署在huggingface上,可以通过链接进行体验。
功能不多,实现了对视频中物体数量的统计(也有没识别出来的),算一个半页面。
- 上传.mp4格式的视频,也可以点击左侧的Example使用示例视频,然后点击播放按钮,APP开始处理视频,经过一段时间之后会在右侧播放处理过的视频,并在下方生成物体计数的表格。原始视频下方还有两个参数滑块:一个是YOLO检测物体所用的置信度阈值(Confidence Threshold),另一个是抽帧间隔(SUBSAMPLE,即每隔多少帧取一帧进行检测)。
- 绘制相关的热图并下载物体计数的表格。
代码如下:
import gradio as gr
from gradio_webrtc import WebRTC
from PIL import Image
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# RGB colours used for boxes and labels, keyed by YOLO class id.
# Class ids that are not listed here are drawn in white.
_CLASS_COLORS = {
    0: (0, 255, 0),       # green
    1: (0, 0, 255),       # blue
    2: (255, 0, 0),       # red
    3: (255, 255, 0),     # yellow
    4: (0, 255, 255),     # cyan
    5: (255, 0, 255),     # magenta
    6: (128, 0, 0),       # dark red
    7: (0, 128, 0),       # dark green
    8: (128, 128, 0),     # dark yellow
    9: (0, 128, 128),     # dark cyan
    10: (128, 0, 128),    # dark magenta
    11: (128, 128, 128),  # dark gray
    12: (192, 192, 192),  # light gray
    13: (255, 128, 0),    # orange
    14: (128, 255, 0),    # lime
}
_DEFAULT_COLOR = (255, 255, 255)


def _annotate_and_write_batch(model, frame_batch, confidence_threshold, output_video):
    """Detect objects in one batch of BGR frames, draw them, and write the frames.

    Args:
        model: Loaded YOLO model, called as ``model(frames, conf=...)``.
        frame_batch (list): BGR frames (as returned by OpenCV); annotated in place.
        confidence_threshold (float): Detections at or below this score are ignored.
        output_video: Open ``cv2.VideoWriter`` that receives the annotated frames.

    Returns:
        dict: Class label -> number of detections summed over the whole batch.
    """
    cls_count = {}
    results = model(frame_batch, conf=confidence_threshold)
    for frame, result in zip(frame_batch, results):
        for x1, y1, x2, y2, conf, cls in result.boxes.data.tolist():
            if conf <= confidence_threshold:
                continue
            class_id = int(cls)
            label = result.names[class_id]
            cls_count[label] = cls_count.get(label, 0) + 1
            # The palette is written as RGB but the frame is BGR, so reverse it.
            color = _CLASS_COLORS.get(class_id, _DEFAULT_COLOR)[::-1]
            top_left = (int(x1), int(y1))
            cv2.rectangle(frame, top_left, (int(x2), int(y2)), color, 2)
            cv2.putText(frame, label, top_left, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        output_video.write(frame)
    return cls_count


def stream_object_detection(input_video, confidence_threshold, SUBSAMPLE):
    """Run YOLO detection on a video and save an annotated, half-size copy.

    Every ``SUBSAMPLE``-th frame is kept, resized to half the source resolution,
    and processed in batches of ``SUBSAMPLE * output_fps`` frames. Each batch
    contributes one row to the returned count table.

    Args:
        input_video (str): Path to the input video file.
        confidence_threshold (float): Confidence threshold for object detection.
        SUBSAMPLE (int): Keep one frame out of every ``SUBSAMPLE`` frames.

    Returns:
        tuple: ``(output_path, counts)`` where ``output_path`` is the path of the
        written video and ``counts`` is a ``pandas.DataFrame`` with one row per
        batch and one integer column per detected class label.

    Raises:
        gr.Error: If no video was supplied or it cannot be opened.
    """
    if not input_video:
        raise gr.Error("No input video provided.")

    subsample = max(1, int(SUBSAMPLE))
    output_video_name = "output.mp4"

    video_capture = cv2.VideoCapture(input_video)
    if not video_capture.isOpened():
        video_capture.release()
        raise gr.Error("Could not open the input video.")

    output_video = None
    class_count_summary = []
    try:
        fps = int(video_capture.get(cv2.CAP_PROP_FPS))
        # Guard against fps < SUBSAMPLE (or an unreadable fps), which would give
        # an output rate of 0 and a batch size of 0.
        desired_fps = max(1, fps // subsample)
        batch_size = subsample * desired_fps

        width = max(1, int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2)
        height = max(1, int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2)

        video_codec = cv2.VideoWriter_fourcc(*"avc1")
        output_video = cv2.VideoWriter(
            output_video_name, video_codec, desired_fps, (width, height)
        )

        model = YOLO("yolo11n.pt")

        frame_batch = []
        frame_index = 0
        while True:
            success, frame = video_capture.read()
            if not success:
                break

            if frame_index % subsample == 0:
                # Resize to the exact writer size so odd source dimensions can
                # never produce a frame the writer rejects. Frames stay in BGR,
                # which is what both the model and the writer expect.
                frame_batch.append(cv2.resize(frame, (width, height)))
            frame_index += 1

            if len(frame_batch) == batch_size:
                class_count_summary.append(
                    _annotate_and_write_batch(
                        model, frame_batch, confidence_threshold, output_video
                    )
                )
                frame_batch = []

        # Flush the final, partially filled batch so the tail is not dropped.
        if frame_batch:
            class_count_summary.append(
                _annotate_and_write_batch(
                    model, frame_batch, confidence_threshold, output_video
                )
            )
    finally:
        video_capture.release()
        if output_video is not None:
            output_video.release()

    # Classes missing from a batch count as 0 rather than NaN.
    cls_count_summary_df = pd.DataFrame(class_count_summary).fillna(0).astype(int)
    cls_count_summary_df.info()
    return output_video_name, cls_count_summary_df


def plot_heatmap_csv(cls_count_summary_df):
    """Render the per-batch class counts as a heatmap image.

    Args:
        cls_count_summary_df: Table of counts (rows = batches, columns = classes).

    Returns:
        str | None: Path of the saved PNG, or ``None`` when there is nothing to plot.
    """
    counts = pd.DataFrame(cls_count_summary_df).apply(pd.to_numeric, errors="coerce")
    if counts.empty or counts.isna().all().all():
        return None

    # Use a dedicated figure and close it afterwards; drawing on the implicit
    # global figure stacks a new heatmap and colorbar on every call.
    fig, ax = plt.subplots()
    try:
        sns.heatmap(counts.T, cbar=True, ax=ax)
        fig.savefig('heatmap.png')
    finally:
        plt.close(fig)
    return 'heatmap.png'


def df_to_csv(cls_count_summary_df):
    """Write the count table to ``count.csv`` and return that path."""
    cls_count_summary_df.to_csv('count.csv', index=False)
    return 'count.csv'


with gr.Blocks(theme=gr.themes.Base()) as app:
    gr.HTML("""<h1 style='text-align: center'>YOLO视频物体检测 powered by <a href='https://ultralytics.com/' target='_blank'>YOLOv11n</a></h1>""")
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Source")
            examples = gr.Examples(['3285790-hd_1920_1080_30fps.mp4'], inputs=video)
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
            SUBSAMPLE = gr.Slider(
                label="SUBSAMPLE",
                minimum=1,
                maximum=10,
                step=1,
                value=2,
            )
        with gr.Column():
            with gr.Tab("main output"):
                output_video = gr.Video(label="Processed Video", autoplay=True)
                # Column headers come from the returned DataFrame (detected class
                # labels); a fixed header list would mislabel the columns.
                output_df = gr.DataFrame(label="Object count")
            with gr.Tab("plot output"):
                output_plot = gr.Image(label="Object Count Heatmap", show_download_button=True)
                output_csv = gr.DownloadButton(label="CSV")

    # Processing starts when the source video is played; the heatmap is drawn
    # once the count table has been produced.
    video.play(
        fn=stream_object_detection,
        inputs=[video, conf_threshold, SUBSAMPLE],
        outputs=[output_video, output_df],
    ).then(fn=plot_heatmap_csv, inputs=output_df, outputs=output_plot)
    output_csv.click(fn=df_to_csv, inputs=output_df, outputs=output_csv)

if __name__ == "__main__":
    app.launch()
总结
Gradio还是很有特色的一个制作网页应用的库,适合做深度学习的应用。