文字图像匹配度检测软件(基于CLIP、Transformers等实现)

前言

  • 使用CLIP(对比图文预训练方法)提供的图文匹配度检测接口,并使用Hugging Face基于Transformers的机器翻译模型实现离线翻译,因此输入中文或英文均可检测。前端图形化界面使用PyQt5开发,并使用qdarkstyle美化界面,具体效果如下图所示:

在这里插入图片描述

使用方法:

  • 左边一栏是候选文字语句,右边一栏是对应每条文字语句的匹配度(支持中英文)

  • 点击“选择图片”:如果所选图片来自我代码中附带的数据集,那么右边第一列的第一行会自动填入这张图片的正确描述;如果是其他图片,则可以手动输入正确描述。点击“随机抽取中文”或“随机抽取英文”按钮,会在后四行填入随机抽取的中文或英文干扰描述。所有候选的5个描述语句均可手动修改

  • 下面是一个例子:

在这里插入图片描述

从结果可以看出,模型能够正确识别最贴合图片的那句描述,效果很好,并且中英文均支持。我自己也做了很多实验测试,代码中还附带了评估模型准确度的脚本testCode.py

部分代码:

主体代码如下,其余代码以及requirements.txt等打包放在我的资源中,可以下载并配置好相关环境后运行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'txtimgui.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
import random
import os
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QFileDialog
import torch
from PIL import Image
import translate_main
import clip
import warnings

# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")

# Path of the image most recently chosen through the file dialog.  It is
# shared between the slots via a module-level global; initialise it here so
# reading it before any image has been selected yields '' instead of raising
# NameError (a bare module-level ``global`` statement does not bind the name).
imgNamePath = ''


def getPicName(myLine):
    """Return the picture-name part of a label line.

    A label line looks like ``<name><marker><sentence>`` where the marker is
    one of ``"#enc#0 "``, ``"#zhc#1 "`` or ``"#zhc#0 "``.  Markers are tried
    in that order and the text before the first occurrence of the first
    marker present is returned.  A line containing none of the markers is
    returned unchanged.
    """
    for marker in ("#enc#0 ", "#zhc#1 "):
        if marker in myLine:
            return myLine.partition(marker)[0]
    # Fallback marker: partition yields the whole line when it is absent.
    return myLine.partition("#zhc#0 ")[0]


def getPicSentence(myLine):
    """Return the description sentence of a label line.

    A label line looks like ``<name><marker><sentence>`` where the marker is
    one of ``"#enc#0 "``, ``"#zhc#1 "`` or ``"#zhc#0 "`` (tried in that
    order).  Everything after the first occurrence of the first marker
    present is returned, including any trailing newline the line carries.

    Returns ``''`` when the line contains none of the markers (for example a
    blank line at the end of a label file) instead of raising ``IndexError``.
    """
    for marker in ("#enc#0 ", "#zhc#1 ", "#zhc#0 "):
        # partition splits on the first occurrence only, so a sentence that
        # happens to repeat the marker text is not truncated.
        _name, found, sentence = myLine.partition(marker)
        if found:
            return sentence
    return ''


class Ui_MainWindow(object):
    """Main window of the text/image matching tool.

    ``setupUi`` and ``retranslateUi`` build and label the widgets (the file
    header says they were generated by pyuic5 from ``txtimgui.ui``).  The
    other public methods are slots wired to the buttons:

    * ``openImage``       - pick an image and pre-fill its known caption.
    * ``randomExtract``   - fill four distractor captions from the Chinese
                            label file.
    * ``randomExtractEn`` - same, from the English label file.
    * ``matching``        - translate the five captions, score them against
                            the image with CLIP and show/log the result.

    Widget layout: ``lineEdit`` shows the image path, ``lineEdit_2`` ..
    ``lineEdit_6`` hold the five candidate captions and ``lineEdit_7`` ..
    ``lineEdit_11`` show the corresponding scores.
    """

    # Cache for (model, preprocess, device); filled on first use so the CLIP
    # weights are loaded once instead of on every button click.
    _clip_bundle = None

    # (path, encoding) of the label files searched for the caption of the
    # selected image.  They are searched in this order and a later match
    # overwrites an earlier one.
    _LABEL_FILES = (
        ("./dataset/militray_label.txt", "utf-8"),
        ("./dataset/ch_label.txt", "GBK"),
        ("./dataset/enc_label.txt", "GBK"),
    )

    # Inclusive index ranges into the list of captions; one distractor is
    # drawn from each range.
    _DISTRACTOR_RANGES = ((10, 20), (21, 30), (31, 39), (40, 49))

    # Every run of ``matching`` appends one line to this file.  (The original
    # code mentions './testResult/testResult.txt' as the alternative used
    # for Chinese test runs.)
    _RESULT_LOG = './testResult/enTestResult.txt'

    # The first caption counts as a correct match above this probability.
    _MATCH_THRESHOLD = 0.5

    def setupUi(self, MainWindow):
        """Create all widgets, lay them out and connect the button slots."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(800, 600)
        MainWindow.setMinimumSize(QtCore.QSize(80, 30))
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")

        # "Choose image" button, image preview label and image path field.
        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton.setGeometry(QtCore.QRect(30, 90, 91, 31))
        self.pushButton.setObjectName("pushButton")
        self.pushButton.clicked.connect(self.openImage)
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setGeometry(QtCore.QRect(40, 160, 241, 271))
        self.label.setObjectName("label")
        self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
        self.lineEdit.setGeometry(QtCore.QRect(120, 90, 181, 31))
        self.lineEdit.setObjectName("lineEdit")

        # Action buttons: random Chinese captions, start matching, random
        # English captions.
        self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_2.setGeometry(QtCore.QRect(374, 362, 81, 31))
        self.pushButton_2.setObjectName("pushButton_2")
        self.pushButton_2.clicked.connect(self.randomExtract)
        self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_3.setGeometry(QtCore.QRect(514, 362, 81, 31))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_3.clicked.connect(self.matching)
        self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_5.setGeometry(QtCore.QRect(374, 462, 81, 31))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_5.clicked.connect(self.randomExtractEn)

        # Column of the five candidate captions.
        self.widget = QtWidgets.QWidget(self.centralwidget)
        self.widget.setGeometry(QtCore.QRect(310, 100, 331, 221))
        self.widget.setObjectName("widget")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.widget)
        self.verticalLayout.setContentsMargins(0, 0, 0, 0)
        self.verticalLayout.setObjectName("verticalLayout")
        self.lineEdit_2 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_2.setMinimumSize(QtCore.QSize(100, 30))
        self.lineEdit_2.setObjectName("lineEdit_2")
        self.verticalLayout.addWidget(self.lineEdit_2)
        self.lineEdit_3 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_3.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_3.setObjectName("lineEdit_3")
        self.verticalLayout.addWidget(self.lineEdit_3)
        self.lineEdit_4 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_4.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_4.setObjectName("lineEdit_4")
        self.verticalLayout.addWidget(self.lineEdit_4)
        self.lineEdit_5 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_5.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_5.setObjectName("lineEdit_5")
        self.verticalLayout.addWidget(self.lineEdit_5)
        self.lineEdit_6 = QtWidgets.QLineEdit(self.widget)
        self.lineEdit_6.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_6.setObjectName("lineEdit_6")
        self.verticalLayout.addWidget(self.lineEdit_6)

        # Column of the five matching scores.
        self.widget1 = QtWidgets.QWidget(self.centralwidget)
        self.widget1.setGeometry(QtCore.QRect(650, 100, 135, 221))
        self.widget1.setObjectName("widget1")
        self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.widget1)
        self.verticalLayout_2.setContentsMargins(0, 0, 0, 0)
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        self.lineEdit_7 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_7.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_7.setObjectName("lineEdit_7")
        self.verticalLayout_2.addWidget(self.lineEdit_7)
        self.lineEdit_8 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_8.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_8.setObjectName("lineEdit_8")
        self.verticalLayout_2.addWidget(self.lineEdit_8)
        self.lineEdit_9 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_9.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_9.setObjectName("lineEdit_9")
        self.verticalLayout_2.addWidget(self.lineEdit_9)
        self.lineEdit_10 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_10.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_10.setObjectName("lineEdit_10")
        self.verticalLayout_2.addWidget(self.lineEdit_10)
        self.lineEdit_11 = QtWidgets.QLineEdit(self.widget1)
        self.lineEdit_11.setMinimumSize(QtCore.QSize(0, 30))
        self.lineEdit_11.setObjectName("lineEdit_11")
        self.verticalLayout_2.addWidget(self.lineEdit_11)

        MainWindow.setCentralWidget(self.centralwidget)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def openImage(self):
        """Let the user pick an image, preview it and pre-fill its caption.

        The chosen path is stored in the module-level ``imgNamePath`` so that
        ``matching`` (and other code) can read it.  If the path contains
        ``'image/'`` the part after it is looked up in the label files and a
        matching caption is written to the first caption field.  Cancelling
        the dialog leaves the previous selection untouched.
        """
        global imgNamePath
        # getOpenFileName returns (path, selected filter); the arguments are
        # the parent widget, the dialog title, the start directory and the
        # file filters.
        chosen, _selected_filter = QFileDialog.getOpenFileName(
            self.centralwidget, "选择图片", './dataset',
            "*.jpg;;*.png;;All Files(*)")
        if not chosen:
            # Dialog cancelled: keep the current image and captions.
            return

        imgNamePath = chosen
        pixmap = QtGui.QPixmap(chosen)
        if pixmap.isNull():
            # Qt could not decode the file; clear the preview and tell the
            # user, but keep the selection.
            self.label.setPixmap(pixmap)
            self.statusbar.showMessage("无法预览图片: " + chosen)
        else:
            # Scale the picture to the size of the preview label.
            self.label.setPixmap(
                pixmap.scaled(self.label.width(), self.label.height()))
        self.lineEdit.setText(chosen)

        # Label files identify pictures by the path segment after 'image/'.
        segments = chosen.split('image/')
        if len(segments) < 2:
            return
        picture_name = segments[1]
        for label_path, encoding in self._LABEL_FILES:
            caption = self._find_caption(label_path, encoding, picture_name)
            if caption is not None:
                self.lineEdit_2.setText(caption)

    @staticmethod
    def _find_caption(label_path, encoding, picture_name):
        """Return the caption stored for *picture_name*, or ``None``.

        When several lines of the file match, the last one wins.  An
        unreadable or undecodable file yields whatever was found before the
        error (``None`` if nothing was).
        """
        caption = None
        try:
            with open(label_path, encoding=encoding) as label_file:
                for line in label_file:
                    if getPicName(line) != picture_name:
                        continue
                    print(line)
                    try:
                        # Lines read from a file keep their line ending;
                        # drop it before it reaches a single-line edit.
                        caption = getPicSentence(line).rstrip('\r\n')
                    except IndexError:
                        continue
        except (OSError, UnicodeDecodeError):
            return caption
        return caption

    @staticmethod
    def _load_sentences(label_path):
        """Return the captions of all usable lines of a UTF-8 label file."""
        sentences = []
        with open(label_path, encoding='utf-8') as label_file:
            for line in label_file:
                try:
                    sentence = getPicSentence(line).rstrip('\r\n')
                except IndexError:
                    # Line without a caption marker, e.g. a blank line.
                    continue
                if sentence:
                    sentences.append(sentence)
        return sentences

    @classmethod
    def _draw_distractor_indices(cls, available):
        """Pick four caption indices out of *available* captions.

        With enough captions one index is drawn from each of
        ``_DISTRACTOR_RANGES``.  With fewer captions (but at least four) four
        distinct indices are sampled from everything available.  Returns
        ``None`` when there are fewer than four captions.
        """
        highest_index = cls._DISTRACTOR_RANGES[-1][1]
        if available > highest_index:
            return [random.randint(low, high)
                    for low, high in cls._DISTRACTOR_RANGES]
        if available >= len(cls._DISTRACTOR_RANGES):
            return random.sample(range(available), len(cls._DISTRACTOR_RANGES))
        return None

    def _fill_distractors(self, label_path):
        """Fill caption fields 2-5 with random captions from *label_path*."""
        try:
            sentences = self._load_sentences(label_path)
        except (OSError, UnicodeDecodeError) as err:
            self.statusbar.showMessage("无法读取描述文件: %s" % err)
            return
        indices = self._draw_distractor_indices(len(sentences))
        if indices is None:
            self.statusbar.showMessage("描述文件中的语句不足: " + label_path)
            return
        picked = [sentences[index] for index in indices]
        print(*picked)
        targets = (self.lineEdit_3, self.lineEdit_4,
                   self.lineEdit_5, self.lineEdit_6)
        for field, sentence in zip(targets, picked):
            field.setText(sentence)

    def randomExtract(self):
        """Fill the four distractor fields with random Chinese captions."""
        self._fill_distractors("./dataset/militray_label.txt")

    def randomExtractEn(self):
        """Fill the four distractor fields with random English captions."""
        self._fill_distractors("./dataset/militray_enc_label.txt")

    def _ensure_model(self):
        """Return ``(model, preprocess, device)``, loading CLIP on first use."""
        if self._clip_bundle is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model, preprocess = clip.load("ViT-B/32", device=device)
            self._clip_bundle = (model, preprocess, device)
        return self._clip_bundle

    @classmethod
    def _format_result_line(cls, image_path, top_score):
        """Return the log line for one run.

        ``"<path>,<score>,True\\n"`` when the first caption scored above
        ``_MATCH_THRESHOLD``, otherwise ``"False\\n"``.
        """
        if top_score > cls._MATCH_THRESHOLD:
            return str(image_path) + ',' + str(top_score) + ',' + 'True' + '\n'
        return 'False' + '\n'

    def _log_result(self, image_path, top_score):
        """Append the outcome of one run to ``_RESULT_LOG``."""
        try:
            os.makedirs(os.path.dirname(self._RESULT_LOG), exist_ok=True)
            # NOTE(review): no explicit encoding, so appended lines stay
            # consistent with lines already in the file - confirm before
            # switching to UTF-8.
            with open(self._RESULT_LOG, 'a+') as writers:
                writers.write(self._format_result_line(image_path, top_score))
        except OSError as err:
            self.statusbar.showMessage("无法写入结果文件: %s" % err)

    def matching(self):
        """Score the five captions against the selected image with CLIP.

        The captions are first passed through ``translate_main.trans``, then
        tokenised and scored.  The softmax probabilities, rounded to four
        decimals, are shown in the score column and the outcome for the first
        caption is appended to the result log.
        """
        image_path = globals().get('imgNamePath', '')
        if not image_path:
            self.statusbar.showMessage("请先选择图片")
            return

        caption_fields = (self.lineEdit_2, self.lineEdit_3, self.lineEdit_4,
                          self.lineEdit_5, self.lineEdit_6)
        score_fields = (self.lineEdit_7, self.lineEdit_8, self.lineEdit_9,
                        self.lineEdit_10, self.lineEdit_11)

        captions = [field.text() for field in caption_fields]
        s1, s2, s3, s4, s5 = translate_main.trans(*captions)

        model, preprocess, device = self._ensure_model()
        try:
            with Image.open(image_path) as picture:
                image = preprocess(picture).unsqueeze(0).to(device)
        except OSError as err:
            self.statusbar.showMessage("无法打开图片: %s" % err)
            return
        text = clip.tokenize(
            [str(s1), str(s2), str(s3), str(s4), str(s5)]).to(device)

        with torch.no_grad():
            logits_per_image, _logits_per_text = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
        print("文本图像匹配度:", probs)

        # One image was scored, so row 0 holds the five probabilities.
        # Reading the numbers directly avoids parsing numpy's printed form.
        scores = [round(float(value), 4) for value in probs[0]]
        print(*scores)
        for field, score in zip(score_fields, scores):
            field.setText(str(score))

        self._log_result(image_path, scores[0])

    def retranslateUi(self, MainWindow):
        """Set the user-visible texts of the window, label and buttons."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "文字图像匹配度检测"))
        self.pushButton.setText(_translate("MainWindow", "选择图片"))
        self.label.setText(_translate("MainWindow",
                                      "<html><head/><body><p><span style=\" font-size:14pt; font-weight:600;\">图文匹配</span></p></body></html>"))
        self.pushButton_2.setText(_translate("MainWindow", "随机抽取中文"))
        self.pushButton_3.setText(_translate("MainWindow", "开始检测"))
        self.pushButton_5.setText(_translate("MainWindow", "随机抽取英文"))

如果你感觉读后有收获,可以点击下方打赏请作者喝杯咖啡。