不知什么时候起,我被读者惯出这个毛病。或者说,是我培养出读者这个好习气。你评论问一句,我就来一篇万字长文。往前翻翻,最近几篇都是这样。我只写大家想看的。
之前给大家安利过,不管什么平台都能调用AI能力,是拿着YOLOv8举的例子。我感谢上面那些朋友,他们能提出这些问题,解释是真的有需求,而且动手操作了。你写的文章有人看,这叫被心腹赏识。
因此,我打算详细拆解YOLO的导出,并以tflite格式的天生、导入,以及在移动真个详细代码利用为例,给上面的问题一个答案。同时,见告大家,平台只是一个环境,理解事理可交融贯通,透穿平台。
一、模型转换
故事开始了。
我媳妇网购东西后,喜好比价。她买俩吹风机,一个叫①号,一个叫②号,到货了也每天去平台看看详情页。于是,我给她做了一款Android客户端,让她一扫描就进详情页,不用再到订单里去找。
这里面采取了YOLOv8的目标检测技能,先演习天生.pt权重文件,再导出为.tflite模型文件,末了放入Android项目中实现检测功能。下面来看看效果,检测速率是毫秒级别的。
标注和演习方法,参考之前文章《用YOLOv8一站式办理图像分类、检测、分割》,这是前置知识。本日的重点,从天生的.pt权重文件开始。
先验证一下我的best_num.pt文件,它会去识别test文件下图片里的①号、②号。
from ultralytics import YOLOmodel = YOLO("best_num.pt")model.predict(source="test", save=True, save_txt=True)
好,没问题。下面转化为.tflite格式的文件。
from ultralytics import YOLOmodel = YOLO('best_num.pt') model.export(format='tflite')
把稳,我可点运行了啊!
接下来你会看到好永劫光的转圈。由于好多库没有安装。
如读者反馈,这个过程可能会报错,报错的缘故原由得看详细的缺点信息。大多数缺点和环境冲突有关,比如你原来有个2.0,此时它自动去安装个3.2,可能就会产生缺点。因此,我建议你整一个全新的虚拟环境去做。别怕麻烦,给每个项目配一个专属空间,会减少很多不必要的麻烦。
如果是安装好了,那还是很快的。只须要16.5秒。
我们从日志中可以看到,它经历了一番弯曲的转换。
1.2 格式转换的路线起初,它是一个PyTorch模型的.pt文件,名称叫best_num.pt。然后,它被转换为onnx格式的best_num.onnx。
ONNX的全称是Open Neural Network Exchange(开放式神经网络交流)。它是由微软、Facebook、IBM等科技公司在2017年共同发起的一种机制,可以实现不同深度学习框架(如PyTorch、TensorFlow、Caffe2等)模型之间的相互转换。因此,onnx格式是必经之路。
随后,又用onnx2tf工具,以命令行的办法将ONNX模型转换为TensorFlow SavedModel格式,并以best_num_saved_model文件夹保存。这个格式是序列化TensorFlow模型用的。
紧接着,启动TensorFlow Lite的导出过程,将TensorFlow SavedModel模型转换为best_num_float32.tflite的TensorFlow Lite格式。
呜呼呀!
我这5.9MB的.pt文件,终极居然被转为11.6MB的.tflite。这显然弗成,在App里太大了!
其余,我看到saved_model文件下有很多.tflite文件。它们的名字还带着数字:float32.tflite、float16.tflite……这是什么情形?
float32.tflite:全精度模型。参数都以32位浮点数(float32)存储。精度高,运行速率相对会慢。float16.tflite:半精度模型。参数都以16位浮点数(float16)存储。大小是全精度模型的一半,运行速率会快一些。在Android终端设备上的推理,我们须要的是更快,而非特殊精确。由于如果哀求精确,不考虑韶光,我上传到做事端去处理好不好。
说的很有道理,我们可以通过量化来改进大小和速率问题。只须要加一个参数model.export(……, int8=True),再运行一下。
ONNX: export success 2.2s, saved as 'best_num.onnx' (11.6 MB)TensorFlow SavedModel: running 'onnx2tf -i "best_num.onnx" -o "best_num_saved_model" -nuo --verbosity info -oiqt -qt per-tensor'TensorFlow SavedModel: export success 212.0s, saved as 'best_num_saved_model' (38.6 MB)TensorFlow Lite: starting export with tensorflow 2.13.0...TensorFlow Lite: export success 0.0s, saved as 'best_num_saved_model\best_num_int8.tflite' (3.0 MB)Export complete (213.7s)
这次耗时长,用了213.7秒,终极导出模型的大小为3.0MB。我满意了,这个大小放到app才得当。
新增的文件如下:
它给出一个best_num_int8.tflite作为最优选择,这是什么情形?
1.3 模型的量化它叫int8量化模型。此模型被量化时,它将浮点数值映射到8-bit的整数范围,并保存了映射关系。当模型进行推理时,这些整数可以被重新阐明为靠近原始的浮点数值。
量化技能,能在减小模型大小和提高实行速率的同时,仍旧保持相对高的精度。
转换成功喽。
下面我们就来拆解它,理解如何剖析,我们给它传什么数据,以及它又会返给我们若何的结果!
上面的best_num_int8.tflite模型是我们自己演习并转化的。因此,我们理解它的构造和出入参数。
现在换一个故事,有人给了一个xxx.tflite,让你去调用。此时你该如何做呢?实在,用代码就可以剖析出来很多有用的信息。
以下操作,用Python和Android都可以实现。鉴于Python简洁,以是先用它快速演示效果,后面我们还会用Android再做一遍。
假设best_num_int8.tflite便是那个xxx.tflite文件,我们用代码来将它阅读一下。
2.1 Interpreter阐明器对付tflite文件的解析,TensorFlow供应了一个Interpreter类。
import tensorflow as tfinterpreter = tf.lite.Interpreter(model_path='xxx.tflite')interpreter.allocate_tensors()print(interpreter.get_tensor_details())
通过interpreter的get_tensor_details,可以获取全体网络构造的信息。
那么,你获取到这些构造信息,有什么用呢?
2.2 网络构造与层诶,它定义了全体数据流转的格式和操作。这就相称于一套操作步骤,讲述了在哪个步骤会把什么做若何的处理。我们把模型比作馒头生产流水线机器,那第一步是放入面粉,随后对面粉加水,和面,揉面,拉长,切割,终极产出馒头。
因此,一旦我们理解了这套馒头机的流程。我们就能清楚地知道,在机器入口该按照若何的频率倒入多少量的面粉,然后在出口能收成什么形状、多少重量的馒头。
实在这与把图片传给模型,它见告你里面有什么物体类似,都是一个加工致顿的过程。
如果把模型做的大略一些,可以是下面这样:
虽然它不智能,但我相信这有助于你更好地理解构造。
除了用代码读取,很多网站也可以浏览模型文件的构造。比如这个网站 https://netron.app/ 。
这解释这些文件都是公开可读的,并没有什么分外加密。
弱水三千,只取一瓢。模型百层,我只关心输入、输出(犹如面粉与馒头)。倘若读出它们,倒也不难,一行代码可成。
# 获取输入层input_details = interpreter.get_input_details()# 获取输出层output_details = interpreter.get_output_details()
打印一下看看:
input_details:[{'name': 'serving_default_images:0', 'index': 0, 'shape': array([ 1, 640, 640, 3]), 'shape_signature': array([ 1, 640, 640, 3]), 'dtype': numpy.float32, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]output_details: [{'name': 'PartitionedCall:0', 'index': 410, 'shape': array([ 1, 6, 8400]), 'shape_signature': array([ 1, 6, 8400]), 'dtype': numpy.float32, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
这里面,英语单词翻译过来便是阐明。比如sparsity parameters是有关稀疏性参数的信息。这个例子中,值都为空,表示没有进行稀疏性优化。
2.3 形状和数据类型本日我们只关注2项:
shape:描述该层数据的形状。输入层的形状是[1, 640, 640, 3],这意味着该层吸收一个四维数组,第一维是样本数(批次大小,几张图片),后面三个维度分别表示一张图片的高度、宽度和颜色通道数。对付输入层数据形式的组织,一样平常不难。难的是对输出层的剖析和处理,后面一章我们会重点拆解。dtype:表示该层数据的数据类型。我们看到两个层的dtype都是numpy.float32,也便是32位的浮点数。好了,tflite文件我们剖析完了。我也知道,你轻微有点懂,但还不至于全懂。看完后面,可能会有所改进。
三、用Python调用模型文件下面,我们先把输入层的数据组织好,然后调用文件试试看。
我们这个模型的输入层的形状是[1, 640, 640, 3]。
实在展开是这样:
图片列表
图片数据
图片1
640640个像素点,每个点用(R,G,B)3种色值表示
图片2
640640个像素点,每个点用(R,G,B)3种色值表示
图片……
640640个像素点,每个点用(R,G,B)3种色值表示
我们先忽略多个图片,只考虑一张图片的情形,这样能大略些。
来读这么一张图片。
3.1 图片的读取
我们读取一下它的数值:
import cv2image = cv2.imread('num.jpg')print(image.shape)
cv2.imread会把一张图片读取成矩阵数据。image.shape是数据的形状,由3部分构成:图片高/矩阵行数,图片宽/矩阵列数,色彩通道数。
这张图片的shape打印出来是(2162, 2883, 3)。这表示图片尺寸为2883×2162,通道数为3。
我们如果在它的表面套一层,加一个[],它就可以变成(1, 2162, 2883, 3)。但是,现在我们首先要把它的尺寸变为(640, 640, 3),由于输入层的格式是(1, 640, 640, 3)。
到这里,你或许有点质疑,这640是怎么来的,谁规定的?我用960行弗成?
兄弟,弗成的。你没法用小麦粒充当面粉往馒头机里放。
3.2 输入数据的预处理真想要追根溯源,得说你当初用YOLOv8演习时,只实行了一句model.train(data="num.yaml", epochs=80),并没有做其他设置。而未设置的,会走一个默认配置,这个配置在Lib\site-packages\ultralytics\cfg下,名字叫default.yaml,里面就有一个imgsz便是640。一个值,表示640×640是正方形,两个值可以设置宽与高。
看我文章,跟听书似的,能涨不少周边知识。
我们要检测的图片,可能来自摄像头,可能来自用户上传,这个咱们不能限定。我们要做的是将图片修正成640×640。
def pre_img(image): height, width, _ = image.shape # 等比例缩放 if height > width: new_height = 640 new_width = int(640 width / height) else: new_width = 640 new_height = int(640 height / width) image_resized = cv2.resize(image, (new_width, new_height)) # 创建一个640640的白色背景图像 background = np.ones((640, 640, 3), dtype=np.uint8) 255 # 将缩放后的图像粘贴到背景图像的中央位置 start_x = (640 - new_width) // 2 start_y = (640 - new_height) // 2 background[start_y:start_y+new_height, start_x:start_x+new_width] = image_resized return backgroundresize_image = pre_img(image)
为了凑一幅640×640的图像,我们采取的处理办法是:不管图片大小,先让它顶着边放大或者缩小到640×640的框里,然后背景设为白色。
此时打印resize_image.shape就看到了久违的(640, 640, 3)。
把稳,要开始调用模型了!
调用很大略,代码加注释,担保你一看就会!
# 单张图片数据转为浮点型input_image_f32 = resize_image.astype(dtype=np.float32)/ 255# 表面包一层[]组成[1, 640, 640, 3]input_data = np.expand_dims(input_image_f32, axis=0)# 将input_data数据塞给输入层,从索引找到input_index = input_details[0]['index'] # 输入的索引interpreter.set_tensor(input_index, input_data)# 跑一跑interpreter.invoke()# 将输出层的数据拿出来,从索引确定输出层output_index = output_details[0]['index'] # 输出层的索引detect_scores = interpreter.get_tensor(output_index)print(detect_scores.shape, detect_scores)
末了来数据了,便是那个detect_scores。
3.4 输出数据剖析detect_scores.shape:(1, 6, 8400)detect_scores: array([[[7.9783527e-03, 2.5762582e-02, 3.6012750e-02, ..., 8.1423753e-01, 8.3901447e-01, 9.1082019e-01], ... 1.8597868e-03, 1.8911671e-03, 1.9312450e-03]]], dtype=float32)
输出数据的形状是(1, 6, 8400)。
看输入数据的形状,有履历的老CV师傅,尚且能猜到是图片数据。但现在看这个输出数据的形状,就真的须要你对YOLOv8算法轻微理解才行喽。
我给大家阐明一下,这些维度都代表什么。
阐明之前,得再往回倒历史,YOLO是You Only Look Once的简称。这种算法,只须要在图上扫一遍就够了。由于有的算法,须要对图片扫描多遍才能实现目标检测。
于是,YOLO会设置一个最小网格作为基本单位,划分出非常多的大大小小的框。然后检测这些框里面是否有目标,以及是某种物体类型的可行性。
我只演习标注了①②两类目标,以是分类数量是2。
下面就随意马虎理解这个模型的输出啦。
3.4.1 输出层格式解析维度数值
阐明
1
图片批次大小,有几张图片。1代表一张
8400
一张图中划分出的8400个小区域
6
6个数代表 (中央点x, 中央点y, 宽度w, 高度h, 分类1的得分, 分类2的得分)
我们仍旧只关注一张图片,并且把数据处理一下。
# 降维 (1, 6, 8400) -> (6, 8400)detect_score = np.squeeze(detect_scores) # 转换 (6, 8400) -> (8400, 6)output_data = np.transpose(detect_score)
打印一个数据看看print(output_data[0]),输出为:
[0.01971355 0.01480704 0.04122782 0.03146162 0.00014795 0.0001379 ]
这是8400个框中第1个框的数据,6位数便是上面表格里对应的6个含义。
我想画一下这些框。但是可以想象,画面肯定就糊了。咱们这样,只画出种别概率大于某个数值的框。
# 前4个是矩形框 x, y, width, heightboxes = output_data[:, :4] # 后2个是①的概率,②的概率scores = output_data[:, 4:]# 打算每个边界框最高的得分max_scores = np.max(scores, axis=1)# 找到知足一定准确率的框【修正点在这里】keep = max_scores >= 0.6# 得到符合条件的边界框和得分filtered_boxes, filtered_scores = boxes[keep], scores[keep]rimge = resize_image.copy()(height, width) = rimge.shape[:2]for i, box in enumerate(filtered_boxes): x,y,w,h = box # x,y 是中央点的坐标,而且是占宽高的百分比 x1,y1 = (x-w/2)width, (y-w/2)height x2,y2 = (x+w/2)width, (y+w/2)height cv2.rectangle(rimge, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
下图是我画出概率大于0.01和0.601的框,可以看出差异还是挺明显的。
彷佛我们已经从输出数据,找到了目标和位置。
等会儿……彷佛还有一个问题,框的重复情形比较严重。产生的缘故原由便是前面说的8400个框。
3.4.2 NMS非极大值抑制看下图,这3个区域,都是合格的网格,而且也都检测到了目标。你不能说它们谁有错!
这可……怎么办?
此时,你再看开篇那张图,有读者说“是不是NMS不包含?”。我说他们真的有需求,而且存心看了是有缘故原由的。NMS全称是Non-Maximum Suppression,换成中国话便是“非极大值抑制”。
普通来讲,便是打消同类弱者,因此叫非极大值抑制。好比IT界要选出各个开拓措辞的代表人物,来了1000多口子,300多Java,600多PHP。大家一比拟,啊,都是干Java的,都搞多并发,留一个最好的,剩下的多并发走人。那边有两个人一比拟,你是Java,我是PHP,咱们是两类人,没冲突,都留下。末了,肯定就剩下最具有代表性的人了。
我们选用哪些个框的方案,也是同样的道理。技能实现上,就用到了IoU。不是I LOVE U啊,是IoU。全称是intersection over union,便是……你甭管叫啥。我见告你怎么处理,上代码。
def iou(box1, box2): # 打算交集区域的坐标 x1 = max(box1[0], box2[0]) y1 = max(box1[1], box2[1]) x2 = min(box1[2], box2[2]) y2 = min(box1[3], box2[3]) # 打算交集区域的面积 inter_area = max(0, x2 - x1 + 1) max(0, y2 - y1 + 1) # 打算两个边界框的面积 box1_area = (box1[2] - box1[0] + 1) (box1[3] - box1[1] + 1) box2_area = (box2[2] - box2[0] + 1) (box2[3] - box2[1] + 1) # 打算IoU iou = inter_area / float(box1_area + box2_area - inter_area) return iou
这个iou方法的输出,实现的是两个矩形框的交集除以并集。先求出box1、box2的面积,再求出box1和box2重合的面积。末了,用重合面积除以两个框合占的面积。
来一个图就明白啦。
实在便是重合度,0表示不重合,1表示完备一样,0.6表示重叠60%。
那下面,我们就去烦闷……不是,去抑制非极大值就行啦。
# 抑制非极大值方法def non_max_suppression(boxes, scores, threshold=0.8): # 创建一个用于存储保留的边界框的列表 keep = [] # 对得分进行排序 order = scores.argsort()[::-1] # 循环直到所有边界框都被检讨 while order.size > 0: # 将当前最大得分的边界框添加到keep中 i = order[0] keep.append(i) # 打算剩余边界框与当前边界框的IoU ious = np.array([iou(boxes[i], boxes[j]) for j in order[1:]]) # 找到与当前边界框IoU小于阈值的边界框 inds = np.where(ious <= threshold)[0] # 更新order,只保留那些与当前边界框IoU小于阈值的边界框 order = order[inds + 1] return keep# 打算每个边界框的最高得分max_scores = np.max(filtered_scores, axis=1)# 进行处理keep = non_max_suppression(filtered_boxes, max_scores)# 末了留下的候选框final_boxes = filtered_boxes[keep]final_scores = filtered_scores[keep]# 目标的索引indexs = np.argmax(final_scores, axis=1)
上面代码,将这些高质量的候选框,先按照得分进行排序,然后拿最高分跟其他候选比拟。凡是重合度高的,去掉,重合率低的,保留。这个操作就实现了一山不容二虎。
3.5 呈现终极结果我们把之前画框的代码,轻微改动一下。
for i, box in enumerate(final_boxes): …… color_v = (0, 0, 255) if indexs[i] == 0 else (255, 0, 0) cv2.rectangle(rimge, (int(x1), int(y1)), (int(x2), int(y2)), color_v, 2)
加了一个判断,如果是第①个种别用赤色,第②个种别用蓝色。
运行效果如下:
怎么样,我们用.tflite格式完成了目标检测。这与在PyTorch下的.pt文件是一样的效果。
那个读者问,是不是不包含nms?兄弟,有很多成熟的类库可以一句话调用。但是,退一万步讲,就算咱用原生代码自己去写一套,也没有多少行代码。
以是我讲事理很主要,平台只是一个媒介。
下面,咱们就前往Android的天下,再去实现这一套流程。
四、用Android调用模型文件首先声明,存在比我下面讲的,还要大略的实现方法。这个我是知道的。
比如多导入以下两个包,可以很方便地处理关于模型加载,图像与数据转换,乃至NMS的问题。那样,没几行代码。
implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'implementation 'org.tensorflow:tensorflow-lite-task-vision:0.3.0'
但是,我吹了牛了,我说事理可以不受平台限定。因此,我只导入基本的tensorflow-lite包,用来加载tflite文件。其他全用Java代码来写(Kotlin也一样)。
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
4.1 加载模型并推理
首先,build.gradle导入上面最基本的tensorflow-lite包。然后,将我们的best_num_int8.tflite文件,拷贝到assets文件下。
我的文件构造如下所示:
个中,DetectTool.java是我自己写的一个检测工具类,卖力加载tflite模型,处理图片的缩放,以及剖析模型输出层的数据。NonMaxSuppression.java也是自己手敲的一个处理非极大值抑制的算法类。
首先,加载tflite文件。
import org.tensorflow.lite.Interpreter;public class DetectTool { // 从Assets下加载.tflite文件 private static MappedByteBuffer loadModelFile(Context context, String fileName) throws IOException { AssetFileDescriptor fileDescriptor = context.getAssets().openFd(fileName); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } // 构建Interpreter,这是tflite文件的阐明器 public static Interpreter getInterpreter(Context context){ Interpreter.Options options = new Interpreter.Options(); options.setNumThreads(4); Interpreter interpreter = null; try { interpreter = new Interpreter(loadModelFile(context, "best_num_int8.tflite"), options); } catch (IOException e) { throw new RuntimeException("Error loading model file.", e); } return interpreter; }}
把稳,实行这一步时,须要在build.gradle中配置没关系缩.tflite文件(默认是压缩的)。
android { // 新增:没关系缩tflite文件 aaptOptions { noCompress "tflite" }
此时,你就可以在Activity中利用Interpreter了。
// 构建阐明器Interpreter interpreter = DetectTool.getInterpreter(this);// 将要处理的Bitmap图像缩放为640×640Bitmap resize_bitmap = resizeBitmap(bitmap, 640);// 转换为输入层(1, 640, 640, 3)构造的float数组float[][][][] input_arr = bitmapToFloatArray(resize_bitmap);// 构建一个空的输出构造float[][][] outArray = new float[1][6][8400];// 运行阐明器,input_arr是输入,它会将结果写到outArray中interpreter.run(input_arr, outArray);
你仍旧可以用interpreter的各种get方法获取输入输出的层信息。但是,基于前面我们已经理解了它的构造,因此现在可以直接构建对应的构造。
4.2 输入预处理详解个中,resizeBitmap方法与bitmapToFloatArray方法是自己写的。
resizeBitmap用于图片尺寸缩放。
public static Bitmap resizeBitmap(Bitmap source, int maxSize) { int outWidth; int outHeight; int inWidth = source.getWidth(); int inHeight = source.getHeight(); if(inWidth > inHeight){ outWidth = maxSize; outHeight = (inHeight maxSize) / inWidth; } else { outHeight = maxSize; outWidth = (inWidth maxSize) / inHeight; } Bitmap resizedBitmap = Bitmap.createScaledBitmap(source, outWidth, outHeight, false); Bitmap outputImage = Bitmap.createBitmap(maxSize, maxSize, Bitmap.Config.ARGB_8888); Canvas canvas = new Canvas(outputImage); canvas.drawColor(Color.WHITE); int left = (maxSize - outWidth) / 2; int top = (maxSize - outHeight) / 2; canvas.drawBitmap(resizedBitmap, left, top, null); return outputImage;
bitmapToFloatArray是构建输入层的数据格式。
public static float[][][][] bitmapToFloatArray(Bitmap bitmap) { int height = bitmap.getHeight(); int width = bitmap.getWidth(); // 初始化一个float数组 float[][][][] result = new float[1][height][width][3]; for (int i = 0; i < height; ++i) { for (int j = 0; j < width; ++j) { // 获取像素值 int pixel = bitmap.getPixel(j, i); // 将RGB值分离并进行标准化(假设你须要将颜色值标准化到0-1之间) result[0][i][j][0] = ((pixel >> 16) & 0xFF) / 255.0f; result[0][i][j][1] = ((pixel >> 8) & 0xFF) / 255.0f; result[0][i][j][2] = (pixel & 0xFF) / 255.0f; } } return result;}
Bitmap是图片,可以是一张本地图片文件,也可以是从相机的预览回调传来的每一帧图像。
只要通过interpreter.run(input_arr, outArray)后,outArray中就有了却果数据,它的形状便是我们熟习的那个(1, 6, 8400)。
用python时,我们全程是手写算法。在Java中,一样可以做到。
4.3 输出数据的处理// 取出(1, 6, 8400)中的(6, 8400)float[][] matrix_2d = outArray[0];// (6, 8400)变为(8400, 6)float[][] outputMatrix = new float[8400][6];for (int i = 0; i < 8400; i++) { for (int j = 0; j < 6; j++) { outputMatrix[i][j] = matrix_2d[j][i]; }}float threshold = 0.6f; // 种别准确率筛选float non_max = 0.8f; // nms非极大值抑制ArrayList<float[]> boxes = new ArrayList<>();ArrayList<Float> maxScores = new ArrayList();for (float[] detection : outputMatrix) { // 6位数中的后两位是两类的置信度 float[] score = Arrays.copyOfRange(detection, 4, 6); float maxValue = score[0]; float maxIndex = 0; for(int i=1; i < score.length;i++){ if(score[i] > maxValue){ // 找出最大的一项 maxValue = score[i]; maxIndex = i; } } if (maxValue >= threshold) { // 如果置信度超过60%则记录 detection[4] = maxIndex; detection[5] = maxValue; boxes.add(detection); // 筛选后的框 maxScores.add(maxValue); // 筛选后的准确率 }}
这段实现和python差异很大。由于原生Java代码在处理矩阵上基本全靠循环。它不像python可以一句话获取矩阵的横向均匀值、竖向最大值。
因此,我将那6位数中的detection[4]设置为最大值的分类索引,detection[5]存储最大值的分值。
到这里,我们就获取到了分类概率大于60%的所有备选框。这时同样会涌现框重复的情形。须要做一个NMS。
public class NonMaxSuppression { public static float iou(float[] box1, float[] box2) { float x1 = Math.max(box1[0], box2[0]); float y1 = Math.max(box1[1], box2[1]); float x2 = Math.min(box1[2], box2[2]); float y2 = Math.min(box1[3], box2[3]); float interArea = Math.max(0, x2 - x1 + 1) Math.max(0, y2 - y1 + 1); float box1Area = (box1[2] - box1[0] + 1) (box1[3] - box1[1] + 1); float box2Area = (box2[2] - box2[0] + 1) (box2[3] - box2[1] + 1); return interArea / (box1Area + box2Area - interArea); } public static List<float[]> nonMaxSuppression(List<float[]> boxes, List<Float> scores, float threshold){ List<float[]> result = new ArrayList<>(); while (!boxes.isEmpty()) { int bestScoreIdx = scores.indexOf(Collections.max(scores)); float[] bestBox = boxes.get(bestScoreIdx); result.add(bestBox); boxes.remove(bestScoreIdx); scores.remove(bestScoreIdx); List<float[]> newBoxes = new ArrayList<>(); List<Float> newScores = new ArrayList<>(); for (int i = 0; i < boxes.size(); i++) { if (iou(bestBox, boxes.get(i)) < threshold) { newBoxes.add(boxes.get(i)); newScores.add(scores.get(i)); } } boxes = newBoxes; scores = newScores; } return result; }}
iou的打算险些和python的处理一样。nonMaxSuppression则根据Java语法特性,变革了一些。
但是,事理是不变。都是先按照分数排名,然后忽略和高分重合度高的,收录重合率低的。
末了的result是终极结果,它是一个列表,每个子项里面6个数,分别是:中央点x、中央点y、框的宽width、框的高height、属于哪一类class_index、置信概率值。
便是这样,Android也成功实现了。你在程序里调用就可以。
五、小结5.1 源码分享我已经将python代码、Java两个类,以及我的pt、tflite文件,还有测试图片,上传到Github上了。希望得到大家的辅导 https://github.com/hlwgy/yolo2tflite。
本篇文章的代码有点多,能读到这里的,都是好哥们。你可以连续提问,我会连续写文章回答。
我以为AI技能并不难,而且它间隔现实生活也不远。希望我们能一起去探索并运用它。
我是头条@ITF男孩,一个从事人工智能的程序员。