3 Star 40 Fork 10

云中有鹿来 / 华为云2019垃圾分类大赛-resnet50

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
prediction.py 3.94 KB
一键复制 编辑 原始数据 按行查看 历史
陈云 提交于 2019-11-27 18:28 . v1.0.2
import tensorflow as tf
import re
import json
from data_process import preprocess_img,preprocess_img_from_Url
from models.resnet50 import ResNet50
from keras.layers import Dense,Dropout,BatchNormalization,GlobalAveragePooling2D
from keras.models import Model
import numpy as np
from keras import regularizers
from tensorflow.python.keras.backend import set_session
# Session configuration for the TensorFlow 1.x session shared by this module.
config = tf.ConfigProto()
# Allocate GPU memory on demand rather than reserving it all at start-up.
config.gpu_options.allow_growth = True
# Upper bound on the fraction of GPU memory this process may use.
config.gpu_options.per_process_gpu_memory_fraction = 0.7
sess = tf.Session(config=config)
# Global configuration flags. The help strings (kept verbatim, they are
# runtime text) read: number of garbage classes / model input image size /
# image batch size.
tf.app.flags.DEFINE_integer('num_classes', 40, '垃圾分类数目')
tf.app.flags.DEFINE_integer('input_size', 224, '模型输入图片大小')
tf.app.flags.DEFINE_integer('batch_size', 16, '图片批处理大小')
FLAGS = tf.app.flags.FLAGS
# Weights file loaded by init_artificial_neural_network().
h5_weights_path = './output_model/best.h5'
# Attach the classification head on top of the backbone
def add_new_last_layer(base_model, num_classes):
    """Stack the classification head on ``base_model`` and return the new model.

    The head is, in order: global average pooling, dropout (0.5), a 512-unit
    ReLU dense layer with L2 regularisation (1e-4), batch normalisation,
    dropout (0.5) and a ``num_classes``-way softmax dense layer.

    Args:
        base_model: Keras model whose ``output`` feeds the head and whose
            ``input`` becomes the input of the returned model.
        num_classes: number of units in the final softmax layer.

    Returns:
        A Keras ``Model`` mapping ``base_model.input`` to the softmax output.
    """
    # Layers are instantiated in the order they are applied, so any
    # auto-generated layer name (the unnamed softmax layer) is unaffected.
    head_layers = (
        GlobalAveragePooling2D(name='avg_pool'),
        Dropout(0.5, name='dropout1'),
        Dense(512, activation='relu',
              kernel_regularizer=regularizers.l2(0.0001), name='fc2'),
        BatchNormalization(name='bn_fc_01'),
        Dropout(0.5, name='dropout2'),
        Dense(num_classes, activation='softmax'),
    )
    features = base_model.output
    for head_layer in head_layers:
        features = head_layer(features)
    return Model(inputs=base_model.input, outputs=features)
# Build the model
def model_fn(FLAGS):
    """Build and compile the classifier: frozen ResNet50 backbone plus new head.

    Args:
        FLAGS: flag container; ``input_size`` and ``num_classes`` are read.

    Returns:
        The Keras model returned by ``add_new_last_layer``, compiled with the
        ``"adam"`` optimizer, categorical cross-entropy loss and an accuracy
        metric.
    """
    side = FLAGS.input_size
    backbone = ResNet50(
        weights="imagenet",
        include_top=False,
        pooling=None,
        input_shape=(side, side, 3),
        classes=FLAGS.num_classes,
    )
    # Freeze every backbone layer; only the layers added below stay trainable.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False
    classifier = add_new_last_layer(backbone, FLAGS.num_classes)
    classifier.compile(
        optimizer="adam",
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier
# Build the inference model and load its trained weights
def init_artificial_neural_network(sess):
    """Register ``sess`` with the Keras backend and return the loaded model.

    Args:
        sess: TensorFlow session passed to ``set_session`` before the model
            is built.

    Returns:
        The model from ``model_fn(FLAGS)`` after loading the weights stored
        at ``h5_weights_path`` (matched to layers by name).
    """
    set_session(sess)
    network = model_fn(FLAGS)
    weights_file = h5_weights_path
    network.load_weights(weights_file, by_name=True)
    return network
# Classify a single image
def _sum_tta_predictions(model, test_data, tta_num):
    """Return the element-wise sum of class scores over the first ``tta_num`` samples."""
    total = None
    for i in range(tta_num):
        # Add a leading batch axis so a single sample can go through predict().
        x_test = test_data[i][np.newaxis, :, :, :]
        scores = np.asarray(model.predict(x_test)[0])
        # Accumulate with NumPy addition. ``+=`` on a Python list would extend
        # the list with the scores instead of adding them element-wise.
        total = scores if total is None else total + scores
    return total


def prediction_result_from_img(model, imgurl):
    """Predict the garbage class of one image using 5-way test-time augmentation.

    The label table is read from
    ``./garbage_classify/garbage_classify_rule.json``. ``imgurl`` is treated as
    a remote image when it matches the http(s) pattern below and as a local
    path otherwise. The class scores of the first five preprocessed samples
    are summed and the arg-max index is looked up in the label table.

    Args:
        model: object exposing ``predict(batch)`` that returns one score
            vector per batch row.
        imgurl: http(s) URL or local file path of the image.

    Returns:
        The label-table entry for the predicted class index, or ``False`` when
        the image could not be read or any exception occurred.
    """
    tta_num = 5
    try:
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open("./garbage_classify/garbage_classify_rule.json", 'r',
                  encoding='utf-8') as load_f:
            load_dict = json.load(load_f)
        if re.match(r'^https?:/{2}\w.+$', imgurl):
            test_data = preprocess_img_from_Url(imgurl, FLAGS.input_size)
        else:
            test_data = preprocess_img(imgurl, FLAGS.input_size)
        # NOTE(review): the preprocess helpers presumably return the scalar 0
        # on failure -- confirm against data_process. The scalar test avoids
        # the ambiguous truth value of ``ndarray != 0``.
        if np.isscalar(test_data) and test_data == 0:
            print('-------文件读取错误----------')
            return False
        predictions = _sum_tta_predictions(model, test_data, tta_num)
        pred_label = int(np.argmax(predictions, axis=0))
        print('-------深度学习垃圾分类预测结果----------')
        print(pred_label)
        print(load_dict[str(pred_label)])
        print('-------深度学习垃圾分类预测结果--------')
        return load_dict[str(pred_label)]
    except Exception as e:
        print('发生了异常-prediction:', e)
        # Same falsy failure value as the read-error branch above.
        return False
# if __name__ == "__main__":
# test_img('https://timgsa.baidu.com/timg?image&quality=80&size=b9999_10000&sec=1572815106329&di=bf107149c926f3114c25e74ef3b79275&imgtype=0&src=http%3A%2F%2Fwww.yejs.com.cn%2Fuserfiles%2FDSC05339.JPG')
# test_img('./test_data/new_img1835.jpg')
Python
1
https://gitee.com/likecy/garbage-resnet50-sever.git
git@gitee.com:likecy/garbage-resnet50-sever.git
likecy
garbage-resnet50-sever
华为云2019垃圾分类大赛-resnet50
master

搜索帮助