1 Star 0 Fork 1

guibao233 / fira

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
fira_Collect_offline.py 11.02 KB
一键复制 编辑 原始数据 按行查看 历史
guibao233 提交于 2022-09-22 15:02 . 此部分为fira核心代码
# encoding: utf-8
"""fira project data-collection program.

Time: 2022-06-05 18:15
Authors: Class of 2019, Robot Engineering (sorted by pinyin):
    Gui Yuanze, Su Qi, Yan Huan
File: fira_Collect_offline.py
Software: PyCharm
"""
# The description above must precede every statement to be the module
# docstring (``__doc__``); after ``__author__`` it was a no-op string.
__author__ = 'Gui'
import rospy
import numpy as np
import cv2
from robot import Robot
from geometry_msgs.msg import Twist
import time
import os
# --- Chassis teleoperation -------------------------------------------------
rate = 10          # command loop frequency in Hz (passed to rospy.Rate)
v = 0.1            # magnitude written to Twist.linear.x by the drive keys
angular = 0.1      # magnitude written to Twist.angular.z by the turn keys

# --- Fixed per-channel binarisation thresholds (cv2.threshold) -------------
threshold_b = 40   # blue; not used in this file (blue uses adaptiveThreshold)
threshold_g = 30   # green
threshold_r = 50   # red
class CollectTrainingData(object):
    """Interactively collect labelled, binarised camera frames for training.

    Creating an instance immediately runs the capture session (``__init__``
    calls ``collect_image``), so the constructor blocks until the user quits.

    Keys are read from the OpenCV preview windows.

    Collection keys store the current processed frame with a one-hot label::

        key  label  meaning
        w    0      forward
        a    1      turn left
        d    2      turn right
        s    3      stop
        t    4      sprint ("rush")

    Chassis keys publish a Twist so the robot can be positioned between shots::

        u i o    forward-left   forward   forward-right
        j k l    spin left      stop      spin right
        m , .    back-left      back      back-right

    ``q`` ends the session. The session also ends automatically once every
    label holds ``self.range`` images.

    Output: ``<data_path>/labeled_img_data_<unix time>.npz`` containing
    ``train`` (N x video_height*video_width float array, one flattened
    180x480 binarised image per row), ``train_labels`` (N x NUM one-hot float
    array) and ``num_list`` (number of images stored per label).
    """

    # Key code -> (label index, progress message) for the collection keys.
    _COLLECT_KEYS = {
        ord('w'): (0, "Forward image collect: "),
        ord('a'): (1, "Left image collect: "),
        ord('d'): (2, "Right image collect: "),
        ord('s'): (3, "Stop image collect: "),
        ord('t'): (4, "rush image collect: "),
    }

    def __init__(self):
        self.raw_height = 480  # raw camera frame height (not used in this file)
        self.raw_width = 640  # raw camera frame width (not used in this file)
        self.video_width = 480  # width frames are resized to; saved image width
        self.video_width_save = 480  # not used in this file
        self.video_height = 180  # saved image height
        self.channels = 1  # saved images are single-channel
        self.NUM = 5  # number of labels: 0..4
        self.range = 300  # maximum number of images per label
        self.data_path = "dataset"
        self.saved_file_name = 'labeled_img_data_' + str(int(time.time()))
        # Chassis interface (project-local Robot wrapper).
        self.robot = Robot()
        # Rate limiter for the command loop.
        self.rate_run = rospy.Rate(rate)
        self.twist = Twist()
        # One-hot label vectors: row i is the label vector for class i.
        self.k = np.eye(self.NUM)
        self.collect_image()  # start the capture session

    def collect_image(self):
        """Run the capture session, then report and save what was collected.

        Ctrl+C (KeyboardInterrupt) or a ROS shutdown ends the session early
        but the frames gathered so far are still saved. The OpenCV windows
        are always closed.
        """
        num_list = [0] * self.NUM  # images stored so far, per label
        rows = []  # flattened uint8 images, one per sample
        label_ids = []  # label index of each entry in ``rows``
        print("prepar to continue.............")
        try:
            self._run_loop(num_list, rows, label_ids)
        except (KeyboardInterrupt, rospy.ROSInterruptException):
            print("..............interrupted, saving collected data.............")
        finally:
            cv2.destroyAllWindows()

        img, lbl = self._stack_samples(rows, label_ids)
        print("image shape:", img.shape)
        print("label shape:", lbl.shape)
        print("\n")
        print("forward images num:", num_list[0])
        print("forward left images num:", num_list[1])
        print("forward right images num:", num_list[2])
        print("stop sign images num:", num_list[3])
        print("rush images num:", num_list[4])
        self._save_dataset(img, lbl, num_list)

    def _run_loop(self, num_list, rows, label_ids):
        """Process frames and keys until quit/all full, filling the lists in place.

        The lists are mutated (rather than returned) so that samples
        gathered before an interrupt remain available to the caller.
        """
        motion_keys = self._motion_bindings()
        while True:
            res = self._preprocess(self.robot.get_image())
            # Low 8 bits of the key code; waitKey returns -1 (-> 255) on timeout.
            command = cv2.waitKey(100) & 0xFF
            if command == 255:
                # NOTE(review): this zero Twist is never published because of
                # the `continue`, so no stop command is sent when no key is
                # pressed. Kept as-is; confirm whether that is intended.
                self.twist.linear.x = 0
                self.twist.angular.z = 0
                continue
            if command == ord('q'):
                print("..............quiting.............")
                return
            if command in self._COLLECT_KEYS:
                label, message = self._COLLECT_KEYS[command]
                if num_list[label] >= self.range:
                    print("list full!!!")
                    continue
                num_list[label] += 1
                rows.append(np.reshape(res, -1))
                label_ids.append(label)
                print(message, num_list[label])
                # Checked here because counts only change on collection keys.
                if self._all_full(num_list):
                    print("---------All list full!!!----------")
                    return
                continue  # collection keys do not publish a chassis command
            if command in motion_keys:
                linear_x, angular_z, message = motion_keys[command]
                self.twist.linear.x = linear_x
                self.twist.angular.z = angular_z
                print(message)
            # Motion keys and any other key publish the current Twist.
            self.robot.publish_twist(self.twist)
            self.rate_run.sleep()

    def _preprocess(self, frame):
        """Show previews and return the binarised video_height x video_width image.

        Steps: resize to video_width x (0.75 * video_width), 3x3 median blur,
        keep rows 90:270, binarise each colour channel separately, stack the
        three single-channel masks vertically and area-resize the stack.
        """
        cv2.imshow("orgin", frame)
        resized_height = int(self.video_width * 0.75)
        frame = cv2.resize(frame, (self.video_width, resized_height))
        frame = cv2.medianBlur(frame, 3)
        roi = frame[90:270, :]  # fixed crop band of the resized frame
        cv2.imshow("review", roi)
        # Channel order assumed to be OpenCV's BGR.
        blue, green, red = cv2.split(roi)
        blue = cv2.adaptiveThreshold(blue, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                     cv2.THRESH_BINARY, 3, 5)
        _, green = cv2.threshold(green, threshold_g, 255, cv2.THRESH_BINARY)
        _, red = cv2.threshold(red, threshold_r, 255, cv2.THRESH_BINARY)
        stacked = cv2.vconcat([blue, green, red])
        res = cv2.resize(stacked, (self.video_width, self.video_height),
                         interpolation=cv2.INTER_AREA)
        cv2.imshow("review2", res)
        return res

    @staticmethod
    def _motion_bindings():
        """Return {key code: (linear.x, angular.z, message)} for the chassis keys.

        Built at call time because it reads the module-level ``v``/``angular``.
        """
        return {
            ord('i'): (v, 0, "前进"),
            ord('u'): (v, angular, "左转"),
            ord('o'): (v, -angular, "右转"),
            ord(','): (-v, 0, "后退"),
            ord('j'): (0.1 * v, angular, "左自转"),
            ord('l'): (0.1 * v, -angular, "右自转"),
            ord('k'): (0, 0, "停止"),
            ord('m'): (-v, -angular, "左后"),
            ord('.'): (-v, angular, "右后"),
        }

    def _all_full(self, num_list):
        """Return True once every label has reached ``self.range`` images."""
        return all(count >= self.range for count in num_list)

    def _stack_samples(self, rows, label_ids):
        """Stack collected samples into (images, one-hot labels) float arrays.

        With no samples, returns empty arrays of shape
        (0, video_height * video_width) and (0, NUM).
        """
        if rows:
            img = np.vstack(rows).astype(float)
        else:
            img = np.zeros((0, self.video_height * self.video_width), dtype=float)
        lbl = self.k[np.asarray(label_ids, dtype=int)]
        return img, lbl

    def _save_dataset(self, img, lbl, num_list):
        """Write the dataset to ``<data_path>/<saved_file_name>.npz``.

        Errors raised while writing the file are printed, not re-raised.
        """
        if not os.path.exists(self.data_path):
            os.makedirs(self.data_path)
        name = os.path.join(self.data_path, self.saved_file_name + '.npz')
        try:
            print(".................saving file...............")
            np.savez(name, train=img, train_labels=lbl, num_list=num_list)
            print("saving file:", name)
        except (IOError, OSError) as e:
            print(e)
def _main():
    """Run one data-collection session; Ctrl+C exits without a traceback."""
    try:
        # Constructing the collector runs the whole capture session.
        CollectTrainingData()
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    _main()
Python
1
https://gitee.com/guibao2/fira.git
git@gitee.com:guibao2/fira.git
guibao2
fira
fira
master

搜索帮助