您的位置:首页 > 汽车 > 时评 > 果蔬识别系统性能优化之路(五)

果蔬识别系统性能优化之路(五)

2024/12/26 9:29:18 来源:https://blog.csdn.net/cpa0701/article/details/142257056  浏览:    关键词:果蔬识别系统性能优化之路(五)

目录

    • 前情提要
      • 剩下问题
    • 解决方案
      • 新建storeFeature表
      • 实现ivf的动态增删改查
    • 结语

前情提要

果蔬识别系统性能优化之路(四)

剩下问题

  1. 新建 store_feature 表,关联 store 表(storeCode)与 feature 表(featureId),对数据库进行规范化:用一张独立的映射表记录 storeCode 与 feature 的对应关系,从而可以使用简单的 WHERE 条件来充分利用索引
  2. 实现对特征向量ivf的增删改查

解决方案

新建storeFeature表

  1. 新建 store 表和 storeFeature 表
import { Entity, PrimaryGeneratedColumn, Column, OneToMany } from 'typeorm';
import { StoreFeature } from '../../feature/entities/store-feature.entity';@Entity()
export class Store {
  /** Auto-generated surrogate primary key. */
  @PrimaryGeneratedColumn()
  id: number;

  /** Business key of the store; must be unique (store_feature.storeCode references it). */
  @Column({ unique: true })
  storeCode: string;

  /** Optional human-readable store name (column is nullable). */
  @Column({ nullable: true })
  storeName: string;

  /** Link rows connecting this store to its features (inverse side of StoreFeature.store). */
  @OneToMany(() => StoreFeature, (link) => link.store)
  storeFeatures: StoreFeature[];
}
import { Entity, ManyToOne, JoinColumn, PrimaryGeneratedColumn } from 'typeorm';
import { Store } from '../../store/entities/store.entity';
import { Feature } from './feature.entity';@Entity()
export class StoreFeature {
  /** Auto-generated surrogate primary key of the link row. */
  @PrimaryGeneratedColumn()
  id: number;

  /**
   * Owning store. The FK column `storeCode` points at `store.storeCode` (not `store.id`),
   * and the row is removed by the database when the store is deleted (ON DELETE CASCADE).
   */
  @ManyToOne(() => Store, { onDelete: 'CASCADE' })
  @JoinColumn({ referencedColumnName: 'storeCode', name: 'storeCode' })
  store: Store;

  /**
   * Linked feature. The FK column `featureId` points at `feature.id`, and the row is
   * removed by the database when the feature is deleted (ON DELETE CASCADE).
   */
  @ManyToOne(() => Feature, { onDelete: 'CASCADE' })
  @JoinColumn({ referencedColumnName: 'id', name: 'featureId' })
  feature: Feature;
}

storeFeature表关联store表和feature表

  2. 对 feature.service 进行大改造
import { Injectable } from '@nestjs/common';
import { CreateFeatureDto } from './dto/create-feature.dto';
import { Feature } from './entities/feature.entity';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository, In } from 'typeorm';
import { RedisService } from '../redis/redis.service';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';
import * as FormData from 'form-data';
import { Img } from '../img/entities/img.entity';
import { Store } from '../store/entities/store.entity';
import { StoreFeature } from './entities/store-feature.entity';@Injectable()
export class FeatureService {
  /** Base URL of the Python feature-extraction / IVF service. */
  private static readonly PYTHON_SERVICE_URL = 'http://localhost:5000';

  constructor(
    @InjectRepository(Feature)
    private readonly featureRepository: Repository<Feature>,
    @InjectRepository(Img)
    private readonly imgRepository: Repository<Img>,
    @InjectRepository(Store)
    private readonly storeRepository: Repository<Store>,
    @InjectRepository(StoreFeature)
    private readonly storeFeatureRepository: Repository<StoreFeature>,
    private readonly httpService: HttpService,
    private readonly redisService: RedisService,
  ) {}

  /**
   * Create a feature from an uploaded image and link it to a store (creating the store if needed).
   * @param file uploaded image; its bytes are saved to the img table
   * @param createFeatureDto feature fields plus storeCode / storeName
   * @param needSync whether to re-sync the store's index and Redis id list afterwards (default true)
   * @returns the saved Feature
   */
  async create(file: Express.Multer.File, createFeatureDto: CreateFeatureDto, needSync: boolean = true): Promise<Feature> {
    const img = this.imgRepository.create({ img: file.buffer });
    await this.imgRepository.save(img);

    // Saving the feature and resolving the store are independent, so run them concurrently.
    // Plain async helpers (instead of `new Promise(async resolve => ...)`) make a failed save
    // reject Promise.all instead of leaving it pending forever.
    const [feature, store] = await Promise.all([
      this.saveFeature(createFeatureDto, img.id),
      this.findOrCreateStore(createFeatureDto.storeCode, createFeatureDto.storeName),
    ]);

    const storeFeature = this.storeFeatureRepository.create({ feature, store });
    await this.storeFeatureRepository.save(storeFeature);

    if (needSync) {
      await this.syncRedis(createFeatureDto.storeCode);
    }
    return feature;
  }

  /**
   * Ask the Python service to rebuild the store's index and cache the returned id list in Redis.
   * Position i of the cached list is the feature id for search hit i (see predict).
   * @param storeCode store to sync
   */
  async syncRedis(storeCode: string): Promise<void> {
    const startedAt = Date.now();
    const response = await firstValueFrom(
      this.httpService.post(`${FeatureService.PYTHON_SERVICE_URL}/sync`, { storeCode }),
    );
    const { ids } = response.data;
    await this.redisService.set(this.featureDbKey(storeCode), JSON.stringify(ids));
    const elapsed = Date.now() - startedAt;
    console.log(`门店:${storeCode},同步redis耗时:${elapsed}ms`);
  }

  /**
   * List all features linked to a store.
   * @param storeCode store to query
   * @param selectP optional column selection (e.g. ['feature.id', 'feature.label']); all columns when omitted
   */
  async findAll(storeCode: string, selectP?: string[]) {
    return await this.storeScopedQuery(storeCode).select(selectP).getMany();
  }

  /**
   * List all features of a store together with their images.
   * @param storeCode store to query
   */
  async findAllWithImage(storeCode: string): Promise<Feature[]> {
    return await this.storeScopedQuery(storeCode).leftJoinAndSelect('feature.img', 'img').getMany();
  }

  /**
   * Delete a store, its links, and every feature (with image) that no store references any more,
   * then clear and re-sync the store's Redis entry.
   * @param storeCode store to delete; a no-op when it does not exist
   */
  async removeAll(storeCode: string): Promise<void> {
    const store = await this.storeRepository.findOne({ where: { storeCode }, relations: ['storeFeatures'] });
    if (!store) {
      return;
    }

    // Delete the link rows in one statement, then the store itself.
    const linkIds = store.storeFeatures.map((link) => link.id);
    if (linkIds.length > 0) {
      await this.storeFeatureRepository.delete(linkIds);
    }
    await this.storeRepository.remove(store);

    // Features shared with other stores keep their links and therefore survive this sweep.
    await this.removeUnreferencedFeatures();

    await this.redisService.del(this.featureDbKey(storeCode));
    await this.syncRedis(storeCode);
  }

  /**
   * Extract features for an image via the Python service and, unless `justPredict` is 'true',
   * return the store's closest features with one entry per label.
   * @param file image to classify
   * @param num how many unique labels to return (string, parsed base-10)
   * @param storeCode store whose index is searched
   * @param justPredict 'true' to only extract features and skip the lookup
   * @param needList also return the full ordered (non-deduplicated) candidate list
   * @throws Error prefixed with "Prediction failed:" on any failure
   */
  async predict(
    file: Express.Multer.File,
    num: string = '5',
    storeCode: string,
    justPredict: string = 'false',
    needList: boolean = false,
  ) {
    const startTime = Date.now();
    const numInt = parseInt(num, 10);
    const isJustPredict = justPredict === 'true';

    try {
      const formData = new FormData();
      formData.append('file', file.buffer, file.originalname);
      formData.append('storeCode', storeCode);
      formData.append('justPredict', justPredict);

      const response = await firstValueFrom(
        this.httpService.post(`${FeatureService.PYTHON_SERVICE_URL}/predict`, formData),
      );
      const { features, index, predictTime } = response.data;
      if (isJustPredict) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      const featureDatabaseStr = await this.redisService.get(this.featureDbKey(storeCode));
      if (!featureDatabaseStr) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      // Map search positions to feature ids. faiss reports missing neighbours as -1, and the
      // Python index may return positions past the id list; both map to undefined and must not
      // reach the SQL IN clause.
      const featureDatabase: unknown[] = JSON.parse(featureDatabaseStr);
      const ids = ((index ?? []) as number[])
        .map((idx) => featureDatabase[idx])
        .filter((id) => id !== undefined && id !== null);
      if (ids.length === 0) {
        return this.buildResponse([], features, predictTime, startTime, numInt);
      }

      const rows = await this.featureRepository
        .createQueryBuilder('feature')
        .where('feature.id IN (:...ids)', { ids })
        .getMany();
      // Restore similarity order in memory rather than interpolating ids into a FIELD() clause.
      const featureList = this.orderByIds(rows, ids);

      const uniqueList = this.filterUniqueFeatures(featureList, numInt);
      const result = this.buildResponse(uniqueList, features, predictTime, startTime, numInt);
      return needList
        ? { ...result, featureList: featureList.map((feature) => this.stripFeatures(feature)) }
        : result;
    } catch (error) {
      const message = error instanceof Error ? error.message : String(error);
      throw new Error(`Prediction failed: ${message}`);
    }
  }

  /** Keep the first feature per label, in order, up to `limit` entries (none when limit <= 0). */
  private filterUniqueFeatures(featureList: any[], limit: number): any[] {
    const seenLabels = new Set<unknown>();
    const uniqueList: any[] = [];
    for (const feature of featureList) {
      if (uniqueList.length >= limit) break;
      if (seenLabels.has(feature.label)) continue;
      seenLabels.add(feature.label);
      uniqueList.push(feature);
    }
    return uniqueList;
  }

  /** Build the predict() payload; the result list goes under `top<num>` with vectors stripped. */
  private buildResponse(list: any[], features: any, predictTime: string, startTime: number, num: number) {
    const totalTime = `${Date.now() - startTime}ms`;
    return {
      predictTime,
      [`top${num}`]: list.map((item) => this.stripFeatures(item)),
      features,
      totalTime,
    };
  }

  /**
   * Cosine similarity of two equal-length vectors; 0 when either vector has zero magnitude
   * (instead of NaN, which would break sorting).
   * @throws Error when the lengths differ
   */
  cosineSimilarity(vecA: number[], vecB: number[]): number {
    if (vecA.length !== vecB.length) {
      throw new Error('Vectors must be of the same length');
    }
    const dotProduct = vecA.reduce((sum, value, index) => sum + value * vecB[index], 0);
    const magnitudeA = Math.sqrt(vecA.reduce((sum, value) => sum + value * value, 0));
    const magnitudeB = Math.sqrt(vecB.reduce((sum, value) => sum + value * value, 0));
    const denominator = magnitudeA * magnitudeB;
    return denominator === 0 ? 0 : dotProduct / denominator;
  }

  /**
   * Brute-force top-N unique labels by cosine similarity against the store's cached entries.
   * NOTE(review): syncRedis now caches a plain id array under this key, so entries carry no
   * `features`/`label` and every similarity is 0. This looks like a leftover from the pre-IVF
   * design; confirm before relying on it.
   * @param inputFeatures query vector
   * @param num maximum number of unique labels to return
   * @param storeCode store whose cache is read
   */
  async findTopNSimilar(inputFeatures: number[], num: number, storeCode: string): Promise<{ label: string; similarity: number }[]> {
    const featureDatabaseStr = await this.redisService.get(this.featureDbKey(storeCode));
    if (!featureDatabaseStr) {
      return [];
    }
    const featureDatabase = JSON.parse(featureDatabaseStr);
    const similarities: { label: string; similarity: number }[] = featureDatabase.map(
      ({ features, label }: { features?: number[]; label?: string }) => ({
        label: label as string,
        similarity: features ? this.cosineSimilarity(inputFeatures, features) : 0,
      }),
    );
    similarities.sort((a, b) => b.similarity - a.similarity);

    const uniqueLabels = new Set<string>();
    const topNUnique: { label: string; similarity: number }[] = [];
    for (const item of similarities) {
      if (topNUnique.length >= num) break;
      if (uniqueLabels.has(item.label)) continue;
      uniqueLabels.add(item.label);
      topNUnique.push({ label: item.label, similarity: Math.round(item.similarity * 100) / 100 });
    }
    return topNUnique;
  }

  /**
   * Features of a store with a given label, including images.
   * @param label label to match exactly
   * @param storeCode store to query
   */
  async getByName(label: string, storeCode: string): Promise<Feature[]> {
    return await this.storeScopedQuery(storeCode)
      .leftJoinAndSelect('feature.img', 'img')
      .andWhere('feature.label = :label', { label })
      .getMany();
  }

  /**
   * Number of feature vectors a store has for a label. No image join is needed to count.
   * @param label label to match exactly
   * @param storeCode store to query
   */
  async getCountByLabel(label: string, storeCode: string): Promise<number> {
    return await this.storeScopedQuery(storeCode).andWhere('feature.label = :label', { label }).getCount();
  }

  /**
   * Learn several images for one label. Features are extracted one file at a time and the
   * store is synced once at the end. Files that fail are logged and skipped.
   * @param files images to learn
   * @param createFeatureDto shared feature fields (label, storeCode, ...)
   * @returns the created features without their vectors
   */
  async batchStudy(files: Express.Multer.File[], createFeatureDto: CreateFeatureDto) {
    const list: Omit<Feature, 'features'>[] = [];
    for (const file of files) {
      try {
        const { features } = await this.predict(file, '5', createFeatureDto.storeCode, 'true');
        const feature = await this.create(file, { ...createFeatureDto, features }, false);
        list.push(this.stripFeatures(feature));
      } catch (e) {
        console.error(e);
      }
    }
    await this.syncRedis(createFeatureDto.storeCode);
    return list;
  }

  /**
   * Delete a feature row and then its image (when the image relation is loaded).
   * @param feature feature to delete
   */
  async remove(feature: Feature) {
    await this.featureRepository.remove(feature);
    if (feature.img) {
      await this.imgRepository.remove(feature.img);
    }
  }

  /**
   * Remove features from a store. The features are unlinked from `storeCode` only; a feature
   * (and its image) is deleted only when no other store still links to it, matching removeAll.
   * Previously shared features were deleted for every store without re-syncing the others.
   * @param ids comma-separated feature ids; blank or non-integer entries are ignored
   * @param storeCode store to remove them from
   */
  async batchRemove(ids: string, storeCode: string) {
    const featureIds = ids
      .split(',')
      .map((part) => part.trim())
      .filter((part) => part !== '')
      .map(Number)
      .filter((id) => Number.isInteger(id));

    if (featureIds.length > 0) {
      // Drop the links first, then sweep whatever became unreferenced.
      await this.storeFeatureRepository
        .createQueryBuilder()
        .delete()
        .where('featureId IN (:...featureIds)', { featureIds })
        .andWhere('storeCode = :storeCode', { storeCode })
        .execute();
      await this.removeUnreferencedFeatures(featureIds);
    }
    await this.syncRedis(storeCode);
  }

  /**
   * Link existing features to `storeCode` (creating the store if needed) and re-sync it.
   * Features already linked to the target store are skipped.
   * @param storeCode target store
   * @param sourceStoreCode when set, only import features linked to this store; otherwise all stores
   * @param storeName name used if the target store has to be created
   * @returns a summary message with the number of imported features
   */
  async importData(storeCode: string, sourceStoreCode?: string, storeName?: string) {
    // Step 1: features the target store already has. An explicit alias makes the raw row key predictable.
    const existingRows: { featureId: number }[] = await this.storeFeatureRepository
      .createQueryBuilder('storeFeature')
      .select('storeFeature.featureId', 'featureId')
      .where('storeFeature.storeCode = :storeCode', { storeCode })
      .getRawMany();
    const featureIdsToExclude = existingRows.map((row) => row.featureId);

    // Step 2: distinct linked features not yet in the target store.
    const candidateQuery = this.storeFeatureRepository
      .createQueryBuilder('storeFeature')
      .select('DISTINCT storeFeature.featureId', 'featureId');
    if (featureIdsToExclude.length > 0) {
      candidateQuery.where('storeFeature.featureId NOT IN (:...featureIdsToExclude)', { featureIdsToExclude });
    }
    const candidateRows: { featureId: number }[] = await candidateQuery.getRawMany();
    const featureIds = candidateRows.map((row) => row.featureId);

    // Step 3: load those features, optionally restricted to the source store.
    let featuresToImport: Feature[] = [];
    if (featureIds.length > 0) {
      const featureQuery = this.featureRepository.createQueryBuilder('feature').leftJoinAndSelect('feature.img', 'img');
      if (sourceStoreCode) {
        featureQuery.innerJoin('feature.storeFeatures', 'storeFeatures');
      }
      featureQuery.whereInIds(featureIds);
      if (sourceStoreCode) {
        featureQuery.andWhere('storeFeatures.storeCode = :sourceStoreCode', { sourceStoreCode });
      }
      featuresToImport = await featureQuery.getMany();
    }

    // Step 4: link the existing features to the target store (no feature rows are copied).
    const targetStore = await this.findOrCreateStore(storeCode, storeName);
    const links = this.storeFeatureRepository.create(
      featuresToImport.map((feature) => ({ store: targetStore, feature })),
    );
    await this.storeFeatureRepository.save(links);

    await this.syncRedis(storeCode);
    return `同步完成,共导入${featuresToImport.length}条数据`;
  }

  /** Sync every known store in parallel (e.g. at startup). */
  async init() {
    const rows = await this.storeRepository
      .createQueryBuilder('store')
      .select('store.storeCode')
      .distinct(true)
      .getRawMany();
    await Promise.all(rows.map((row) => this.syncRedis(row.store_storeCode)));
    console.log('初始化完成');
  }

  /** Redis key holding the store's id list (search position -> feature id). */
  private featureDbKey(storeCode: string): string {
    return `${storeCode}-featureDatabase`;
  }

  /**
   * Base query for features linked to `storeCode`. It filters on store_feature.storeCode
   * directly instead of also joining the store table, which saves a join and lets an index
   * on that column (if present) be used.
   */
  private storeScopedQuery(storeCode: string) {
    return this.featureRepository
      .createQueryBuilder('feature')
      .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
      .where('storeFeature.storeCode = :storeCode', { storeCode });
  }

  /** Persist a new Feature row pointing at an already-saved image. */
  private async saveFeature(createFeatureDto: CreateFeatureDto, imgId: Img['id']): Promise<Feature> {
    const feature: Feature = this.featureRepository.create({ ...createFeatureDto, imgId });
    await this.featureRepository.save(feature);
    return feature;
  }

  /** Return the store with `storeCode`, creating it (with `storeName`) when it does not exist. */
  private async findOrCreateStore(storeCode: string, storeName?: string): Promise<Store> {
    const existing = await this.storeRepository.findOne({ where: { storeCode } });
    if (existing) {
      return existing;
    }
    const store = this.storeRepository.create({ storeCode, storeName });
    await this.storeRepository.save(store);
    return store;
  }

  /**
   * Delete features (and images) that no store links to any more. When `candidateIds` is
   * given, only those features are considered; otherwise every feature is swept.
   */
  private async removeUnreferencedFeatures(candidateIds?: number[]): Promise<void> {
    if (candidateIds && candidateIds.length === 0) {
      return;
    }
    const query = this.featureRepository
      .createQueryBuilder('feature')
      .leftJoinAndSelect('feature.img', 'img')
      .leftJoin('feature.storeFeatures', 'storeFeature')
      .where('storeFeature.id IS NULL');
    if (candidateIds) {
      query.andWhere('feature.id IN (:...candidateIds)', { candidateIds });
    }
    const orphans = await query.getMany();
    for (const feature of orphans) {
      await this.remove(feature);
    }
  }

  /** Sort rows by the first position of their id in `ids`, the in-memory equivalent of MySQL FIELD(). */
  private orderByIds(rows: Feature[], ids: unknown[]): Feature[] {
    const rank = new Map<string, number>();
    ids.forEach((id, position) => {
      const key = String(id);
      if (!rank.has(key)) {
        rank.set(key, position);
      }
    });
    const positionOf = (feature: Feature) => rank.get(String(feature.id)) ?? Number.MAX_SAFE_INTEGER;
    return [...rows].sort((a, b) => positionOf(a) - positionOf(b));
  }

  /** Shallow copy of `item` without its (large) `features` vector. */
  private stripFeatures<T extends { features?: unknown }>(item: T): Omit<T, 'features'> {
    const { features: _features, ...rest } = item;
    return rest;
  }
}
  3. 结果:性能并没有提升多少,但表之间的关系更清晰了,为之后的扩展打下了基础

实现ivf的动态增删改查

  1. 结论:在当前实现下,IVF 索引只新增(add)向量而不重新训练(train)时,新增的向量无法被正确识别,所以每次新增向量都需要重新训练并添加
  2. python端ivf改造
    detect.py(识别和同步方法)
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
import numpy as np
import time
import gc
from ivf import IVFPQ
from feature import get_feature_by_store_code
import orjson
from concurrent.futures import ThreadPoolExecutor# 加载预训练的 MobileNetV2 模型,不包含顶部的分类层
model = MobileNetV2(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')

# Suffix appended to a store code to form its index key. This is the same key format the
# Node side uses for the store's Redis id list.
FEATURE_DB_SUFFIX = '-featureDatabase'


class MainDetect:
    """Embeds images with MobileNetV2 and searches a per-store IVF index."""

    def __init__(self):
        super().__init__()
        self.image_id = None
        self.image_features = None
        # NOTE(review): loaded but unused; its call in classify_image is commented out.
        # Consider dropping it to save memory.
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        # "<store_code>-featureDatabase" -> IVFPQ index built by sync()
        self.ivfObj = {}

    def classify_image(self, image_data, store_code, just_predict):
        """Embed one encoded image and optionally search the store's index.

        Args:
            image_data: encoded image bytes (anything tf.image.decode_image accepts).
            store_code: store whose index is searched.
            just_predict: the string "false" enables the index search; any other value
                only returns the embedding.

        Returns:
            dict with "outputs" (flattened embedding), "time" (model + search time,
            formatted as "<ms>ms") and "index" (positions into the id list returned by
            sync(); empty when no search ran or the store has no index).
        """
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)  # add batch dimension

        start_time = time.time()
        outputs = model.predict(img)
        # outputs = self.model.predict(outputs)
        # prediction = tf.divide(outputs, tf.norm(outputs))

        positions = []
        if just_predict == "false":
            ivf = self.ivfObj.get(store_code + FEATURE_DB_SUFFIX)
            if ivf is not None:
                hits = ivf.search(outputs).flatten().tolist()
                # faiss reports missing neighbours as -1, and an index trained with random
                # padding may also return padding rows. Neither corresponds to an entry in
                # the id list sync() returned, so keep only positions inside that list.
                num_real = getattr(ivf, 'num_real', None)
                positions = [p for p in hits if p >= 0 and (num_real is None or p < num_real)]
        elapsed_time = time.time() - start_time

        output_data = outputs.flatten().tolist()

        # Drop references to the large tensors and collect eagerly to keep memory flat.
        del img, outputs
        gc.collect()
        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": positions}

    def sync(self, store_code):
        """Rebuild the store's IVF index from the database.

        Returns:
            Feature ids in index order (search position i maps to ids[i]), or [] when the
            store has no features. In that case its index is dropped.
        """
        key = store_code + FEATURE_DB_SUFFIX
        data = get_feature_by_store_code(store_code)
        if len(data) == 0:
            self.ivfObj.pop(key, None)
            return []

        with ThreadPoolExecutor() as executor:
            features_list = list(executor.map(lambda item: orjson.loads(item['features']), data))
        features = np.array(features_list, dtype=np.float32)

        ivf = IVFPQ(features)
        # Remember how many real vectors back this index so classify_image can discard
        # positions that have no id (kept on the index object so both are swapped together).
        ivf.num_real = len(data)
        # Swap in only after the new index is complete: concurrent searches keep using the
        # previous index instead of finding none, and a failed rebuild leaves it in place.
        self.ivfObj[key] = ivf
        return [item['id'] for item in data]

ivf.py(ivf构造)

import faiss
import numpy as np

num_threads = 8
faiss.omp_set_num_threads(num_threads)


class IVFPQ:
    """Inverted-file (IVF) index over a set of feature vectors.

    Despite the name, the index is currently a faiss IndexIVFFlat. The IndexIVFPQ
    variant, which would use ``m`` / ``n_bits``, is commented out.
    """

    def __init__(self, features, nlist=100, m=16, n_bits=8, nprobe=1, min_train_points=3900):
        """Train the coarse quantizer and add ``features`` to the index.

        Args:
            features: array-like of shape (n, d); converted to contiguous float32.
            nlist: number of inverted lists (k-means centroids).
            m, n_bits: only used by the commented-out IndexIVFPQ variant.
            nprobe: number of lists visited per query (1 = faiss default).
            min_train_points: when fewer real vectors are available, training is topped
                up with seeded random vectors. 3900 presumably matches faiss's
                39-points-per-centroid guideline for nlist=100; TODO confirm.
        """
        features = np.ascontiguousarray(features, dtype=np.float32)
        n, d = features.shape
        quantizer = faiss.IndexFlatL2(d)  # L2 distance for the coarse quantizer
        self.index = faiss.IndexIVFFlat(quantizer, d, nlist)
        # self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)
        self.index.nprobe = nprobe

        if n >= min_train_points:
            self.index.train(features)
        else:
            points = min_train_points - n
            # A private seeded generator yields the same values as
            # np.random.seed(points) + np.random.random(...) without touching global RNG state.
            padding = np.random.RandomState(points).random_sample((points, d)).astype('float32')
            self.index.train(np.vstack((features, padding)))

        # Padding is used for training only. Adding it to the index would let searches
        # return positions >= n that have no matching feature id.
        self._add_batched(features)

    def _add_batched(self, xb, batch_size=1000):
        """Add rows of ``xb`` in slices of ``batch_size`` to bound the size of each add call."""
        for start in range(0, len(xb), batch_size):
            self.index.add(xb[start:start + batch_size])

    def search(self, xq, k=100):
        """Return the (nq, k) array of positions of the k nearest vectors (-1 where none)."""
        _, positions = self.index.search(self._as_matrix(xq), k)
        return positions

    def add(self, xb):
        """Add one vector or a batch of vectors to the trained index."""
        self.index.add(self._as_matrix(xb))

    def train(self, xb):
        """(Re)train the coarse quantizer on one vector or a batch."""
        self.index.train(self._as_matrix(xb))

    def sync(self, features):
        """Append ``features`` to the index in a single add.

        faiss expects a 2-D batch, so rows are no longer added one 1-D vector at a time.
        """
        self.add(features)

    @staticmethod
    def _as_matrix(x):
        """Coerce to the contiguous 2-D float32 layout faiss requires."""
        return np.ascontiguousarray(np.atleast_2d(x), dtype=np.float32)

结语

这个项目的优化到这里差不多告一段落了,后续如果还有优化点会继续跟进,稍后会把整体架构图和功能点都梳理一遍。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com