原文：《YOLOv8 分类太阳能板》- 知乎 (zhihu.com)
一、数据集
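ELPV 数据集：https://github.com/zae-bayern/elpv-dataset
（下载后需按电池片类型整理为 images/train/mono、images/train/poly、images/val/mono、images/val/poly 的目录结构——YOLOv8 分类训练要求数据目录下包含 train/ 与 val/ 子目录、每个类别一个文件夹，后文的训练、预测代码均基于该结构。）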
二、数据分析
import os

import matplotlib.pyplot as plt

# Dataset root holding one sub-folder per category.
# NOTE(review): the training/prediction sections below use
# images/train/<category> and images/val/<category>; if the data has already
# been split, point this at e.g. "./images/train" — TODO confirm.
train_dir = "./images"
valid_extensions = ('.jpg', '.png', '.jpeg')
categories = ['poly', 'mono']

# Count the files with an image extension in each category folder.
category_count = {}
for each_category in categories:
    folder_path = os.path.join(train_dir, each_category)
    valid_images = [file for file in os.listdir(folder_path)
                    if file.lower().endswith(valid_extensions)]
    category_count[each_category] = len(valid_images)


def _plot_barh(data, title, xlabel, ylabel, bar_height, fmt=str):
    """Draw a horizontal bar chart of ``data`` (label -> value) and annotate each bar.

    For a horizontal bar chart the bar length (the value) runs along the
    x-axis and the category labels sit on the y-axis.
    """
    plt.figure(figsize=(10, 4))
    bars = plt.barh(list(data.keys()), list(data.values()), bar_height)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    for bar, value in zip(bars, data.values()):
        # Put the value just past the end of the bar, vertically centered on it.
        plt.text(bar.get_width(), bar.get_y() + bar.get_height() / 2,
                 fmt(value), ha='left', va='center')
    plt.show()


# Absolute number of images per category.
_plot_barh(category_count, 'Elpv Type Distribution', 'Count', 'Elpv Type', 0.5)

# Share of the dataset per category, labelled to 3 decimals.
sample_size = sum(category_count.values())
if sample_size == 0:
    raise ValueError(f"No images with extensions {valid_extensions} found under {train_dir!r}")
class_dist = {key: val / sample_size for key, val in category_count.items()}
_plot_barh(class_dist, 'Class Distribution', 'Fraction', 'Class', 0.6,
           fmt=lambda v: str(round(v, 3)))
三、数据可视化
import os

import matplotlib.pyplot as plt

# Dataset root holding one sub-folder per category.
# NOTE(review): later sections use images/train/<category> and
# images/val/<category>; adjust if the data has been split — TODO confirm.
train_dir = "images"
valid_extensions = ('.jpg', '.png', '.jpeg')
categories = ['poly', 'mono']

plt.figure(figsize=(12, 8))
for i, category in enumerate(categories):
    folder_path = os.path.join(train_dir, category)
    # os.listdir order is arbitrary and its first entry need not be an image,
    # so take the first *image* file in sorted order instead of entry [0].
    image_name = next((name for name in sorted(os.listdir(folder_path))
                       if name.lower().endswith(valid_extensions)), None)
    if image_name is None:
        continue  # nothing displayable in this category
    img = plt.imread(os.path.join(folder_path, image_name))
    # One row, one column per category (works for any number of categories).
    plt.subplot(1, len(categories), i + 1)
    # cmap only affects single-channel images: grayscale is shown as gray
    # rather than with the default false-color colormap; RGB(A) ignores it.
    plt.imshow(img, cmap='gray')
    plt.title(category)
    plt.axis("off")
plt.tight_layout()
plt.show()
四、模型训练
from ultralytics import YOLO

# Load the pretrained YOLOv8m ("medium") classification checkpoint.
model = YOLO('yolov8m-cls.pt')

# Fine-tune on our dataset. `data` is the dataset root directory, which is
# expected to contain train/ and val/ sub-folders with one folder per class.
# NOTE(review): 3 epochs is a very short schedule — presumably a quick demo.
model.train(data='images', epochs=3)

# Evaluate the fine-tuned model on the validation split
# (results are saved to runs/classify/val).
res = model.val()
五、模型预测
from ultralytics import YOLO

# Load the fine-tuned weights produced by the training step.
# (To try the stock pretrained model instead, use YOLO('yolov8m-cls.pt');
# loading both just discards the first one.)
model = YOLO('runs/classify/train/weights/best.pt')

# Predict on a single image; show=True displays the result and
# save=True saves the annotated output.
results = model('images/val/mono/cell0001.png', show=True, save=True)
六、PT 模型转 ONNX
from ultralytics import YOLO

# Load the fine-tuned weights produced by the training step.
# (Loading the official 'yolov8m-cls.pt' first would be immediately
# overwritten, so only the custom model is loaded.)
model = YOLO('runs/classify/train/weights/best.pt')

# Export to ONNX and report where the file was written, so the inference
# script can be pointed at it.
onnx_path = model.export(format='onnx')
print('ONNX model saved to:', onnx_path)
七、ONNX 推理
import cv2
import numpy as np
import onnxruntime as rt


def preprocess(img_bgr, height=224, width=224):
    """Turn an OpenCV BGR uint8 image into a (1, 3, height, width) float32 RGB tensor in [0, 1].

    Center-crops to a square, resizes, converts BGR -> RGB and scales to
    [0, 1]. NOTE(review): intended to mirror Ultralytics' classification
    preprocessing — confirm against the installed ultralytics version.
    """
    h, w = img_bgr.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    img = img_bgr[top:top + side, left:left + side]  # no-op for square images
    img = cv2.resize(img, (width, height))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; training used RGB
    img = img.astype(np.float32) / 255.0
    chw = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    return np.ascontiguousarray(chw[np.newaxis])  # add batch dimension


if __name__ == '__main__':
    height, width = 224, 224
    # image_path = 'images/val/mono/cell0001.png'
    image_path = 'images/val/poly/cell0065.png'
    # Index order must match the class indices used at training time.
    # NOTE(review): assumes Ultralytics sorted the class folders alphabetically
    # (mono=0, poly=1) — confirm via the trained model's `names`.
    categories = ['mono', 'poly']

    img0 = cv2.imread(image_path)  # returns None instead of raising on failure
    if img0 is None:
        raise FileNotFoundError(f'Could not read image: {image_path}')

    data = preprocess(img0, height, width)

    # NOTE(review): export writes the .onnx next to best.pt; this assumes it
    # was copied/renamed to v8-cls.onnx — adjust the path if not.
    sess = rt.InferenceSession('v8-cls.onnx', providers=['CPUExecutionProvider'])
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    pred = sess.run([output_name], {input_name: data})[0][0]  # scores for the single image

    max_index = int(np.argmax(pred))
    print('预测的结果为:', categories[max_index])