"""Automatic image classification with OpenAI's CLIP model.

1. Environment setup::

    pip install torch torchvision
    pip install git+https://github.com/openai/CLIP.git
    pip install pillow matplotlib

2. Basic usage: ``CLIPImageClassifier`` below scores an image against an
arbitrary list of class names using the prompt "a photo of {class}"; see the
``__main__`` block at the bottom of this file for an example.
"""

import clip
import numpy as np
import torch
from PIL import Image


class CLIPImageClassifier:
    """Image classifier that ranks user-supplied class names with CLIP."""

    def __init__(self, model_name="ViT-B/32", device=None):
        """Initialize the CLIP classifier.

        Args:
            model_name: CLIP model name, e.g. "ViT-B/32", "ViT-B/16", "ViT-L/14".
            device: Device to run on ("cuda" or "cpu"). When None, "cuda" is
                used if available, otherwise "cpu".
        """
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"使用设备: {self.device}")
        print(f"加载模型: {model_name}")

        # Load the CLIP model together with its matching image preprocessing.
        self.model, self.preprocess = clip.load(model_name, device=self.device)
        self.model.eval()

    def _load_image(self, image_path):
        """Open ``image_path`` as an RGB image; print the error and return None on failure."""
        try:
            # convert() returns a new image, so the source file can be closed here.
            with Image.open(image_path) as img:
                return img.convert("RGB")
        except Exception as e:  # best-effort: callers treat None as "skip this image"
            print(f"无法加载图像: {e}")
            return None

    def _encode_text(self, class_names):
        """Return L2-normalized text embeddings for "a photo of {c}", or None if no classes."""
        if not class_names:
            return None
        text_inputs = clip.tokenize([f"a photo of {c}" for c in class_names]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_inputs)
        return text_features / text_features.norm(dim=-1, keepdim=True)

    def _predict(self, image, class_names, text_features, top_k):
        """Rank ``class_names`` for one loaded image against precomputed text embeddings."""
        if not class_names:
            return []

        image_input = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Cosine similarity scaled by 100 (as in the CLIP README zero-shot
        # example); softmax in float32 since the model may run in fp16 on CUDA.
        logits = 100.0 * image_features @ text_features.T
        probs = logits.float().softmax(dim=-1).cpu().numpy()[0]

        top_indices = np.argsort(probs)[::-1][:top_k]
        return [
            {"class": class_names[idx], "probability": float(probs[idx])}
            for idx in top_indices
        ]

    def classify_image(self, image_path, class_names, top_k=5):
        """Classify a single image.

        Args:
            image_path: Path to the image file.
            class_names: List of candidate class names.
            top_k: Number of top predictions to return.

        Returns:
            A list of {"class": name, "probability": p} dicts in descending
            probability order (at most ``top_k`` entries; empty if
            ``class_names`` is empty), or None if the image cannot be loaded.
        """
        image = self._load_image(image_path)
        if image is None:
            return None
        return self._predict(image, class_names, self._encode_text(class_names), top_k)

    def classify_batch(self, image_paths, class_names, top_k=5):
        """Classify several images; images that fail to load are skipped.

        Returns:
            A list of {"image_path": path, "predictions": [...]} dicts, where
            "predictions" has the same format as ``classify_image`` returns.
        """
        # Text embeddings depend only on class_names: encode them once rather
        # than once per image.
        text_features = self._encode_text(class_names)
        results = []
        for image_path in image_paths:
            image = self._load_image(image_path)
            if image is None:
                continue
            results.append({
                "image_path": image_path,
                "predictions": self._predict(image, class_names, text_features, top_k),
            })
        return results


# Usage example
if __name__ == "__main__":
    # Initialize the classifier.
    classifier = CLIPImageClassifier(model_name="ViT-B/32")

    # Define the candidate classes -- these can be any labels you like.
    class_names = [
        "cat", "dog", "bird", "car", "airplane", "beach", "mountain",
        "forest", "city", "ocean", "apple", "banana", "orange", "person",
        "bicycle",
    ]

    # Classify a single image.
    image_path = "test_image.jpg"  # replace with the path to your image
    results = classifier.classify_image(image_path, class_names, top_k=3)

    if results:
        print("\n分类结果:")
        for i, result in enumerate(results, 1):
            print(f"{i}. {result['class']}: {result['probability']:.2%}")