Training on the Rock Paper Scissors Dataset with TensorFlow

This article walks through training a neural network on the Rock Paper Scissors image dataset. The dataset contains hand-gesture images from people of different races, ages, and genders. First, download the training set and the test set:

```python
import ssl
from pathlib import Path
from urllib.error import URLError
from urllib.request import urlopen

DOWNLOAD_DIR = Path("D:/mldownload")
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

RPS_URL = "https://storage.googleapis.com/learning-datasets/rps.zip"
RPS_TEST_URL = "https://storage.googleapis.com/learning-datasets/rps-test-set.zip"
RPS_ZIP = DOWNLOAD_DIR / "rps.zip"
RPS_TEST_ZIP = DOWNLOAD_DIR / "rps-test-set.zip"


def download_file(url, destination):
    # Skip the download if a non-empty file is already present.
    if destination.exists() and destination.stat().st_size > 0:
        print(f"File already exists, skipping: {destination}")
        return

    # Download into a ".part" file first so an interrupted download
    # never leaves a truncated file under the final name.
    temp_path = destination.with_suffix(destination.suffix + ".part")
    print(f"Downloading: {url}")
    try:
        response = urlopen(url, timeout=120)
    except URLError:
        # Fallback: retry WITHOUT TLS certificate verification
        # (for machines with a broken certificate store).
        context = ssl._create_unverified_context()
        response = urlopen(url, timeout=120, context=context)

    with response, temp_path.open("wb") as file:
        while True:
            data = response.read(1024 * 1024)  # read in 1 MB chunks
            if not data:
                break
            file.write(data)

    temp_path.replace(destination)
    size_mb = destination.stat().st_size / 1024 / 1024
    print(f"Downloaded: {destination} ({size_mb:.1f} MB)")


download_file(RPS_URL, RPS_ZIP)
download_file(RPS_TEST_URL, RPS_TEST_ZIP)
```

Note: set the download directory to match your own environment. If the code above cannot download the dataset, try downloading the files manually in a browser.

Then extract the downloaded archives:

```python
import zipfile


def extract_zip(zip_path, extract_dir):
    if not zip_path.exists():
        raise FileNotFoundError(f"Zip file not found. Run the download cell first: {zip_path}")

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        # testzip() returns the name of the first corrupted member, or None.
        bad_file = zip_ref.testzip()
        if bad_file is not None:
            raise zipfile.BadZipFile(
                f"Zip file looks corrupted at {bad_file}. Delete {zip_path} and download again."
            )
        zip_ref.extractall(extract_dir)
    print(f"Extracted: {zip_path} -> {extract_dir}")


extract_zip(RPS_ZIP, DOWNLOAD_DIR)
extract_zip(RPS_TEST_ZIP, DOWNLOAD_DIR)
```

Check the extraction result and print some information about the training set:

```python
rock_dir = DOWNLOAD_DIR / "rps" / "rock"
paper_dir = DOWNLOAD_DIR / "rps" / "paper"
scissors_dir = DOWNLOAD_DIR / "rps" / "scissors"

for image_dir in [rock_dir, paper_dir, scissors_dir]:
    if not image_dir.exists():
        raise FileNotFoundError(
            f"Directory not found: {image_dir}. "
            "Run the download and extract cells first."
        )

rock_files = sorted(path.name for path in rock_dir.iterdir())
paper_files = sorted(path.name for path in paper_dir.iterdir())
scissors_files = sorted(path.name for path in scissors_dir.iterdir())

print("total training rock images:", len(rock_files))
print("total training paper images:", len(paper_files))
print("total training scissors images:", len(scissors_files))

print(rock_files[:10])
print(paper_files[:10])
print(scissors_files[:10])
```

```
total training rock images: 840
total training paper images: 840
total training scissors images: 840
['rock01-000.png', 'rock01-001.png', 'rock01-002.png', 'rock01-003.png', 'rock01-004.png', 'rock01-005.png', 'rock01-006.png', 'rock01-007.png', 'rock01-008.png', 'rock01-009.png']
['paper01-000.png', 'paper01-001.png', 'paper01-002.png', 'paper01-003.png', 'paper01-004.png', 'paper01-005.png', 'paper01-006.png', 'paper01-007.png', 'paper01-008.png', 'paper01-009.png']
['scissors01-000.png', 'scissors01-001.png', 'scissors01-002.png', 'scissors01-003.png', 'scissors01-004.png', 'scissors01-005.png', 'scissors01-006.png', 'scissors01-007.png', 'scissors01-008.png', 'scissors01-009.png']
```

Display two training images from each class:

```python
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

pic_index = 2

# Take the first two file names of each class (slice [0:2]).
next_rock = [rock_dir / fname for fname in rock_files[pic_index - 2:pic_index]]
next_paper = [paper_dir / fname for fname in paper_files[pic_index - 2:pic_index]]
next_scissors = [scissors_dir / fname for fname in scissors_files[pic_index - 2:pic_index]]

for img_path in next_rock + next_paper + next_scissors:
    img = mpimg.imread(img_path)
    plt.imshow(img)
    plt.axis("off")
    plt.show()
```

Next, train and evaluate the model with Keras in TensorFlow. Keras is an open-source neural network library. TensorFlow integrates its API as `tf.keras`, so it can be used directly. First, import the required modules:

```python
import tensorflow as tf
# ImageDataGenerator is imported from tf.keras instead of the deprecated
# standalone keras_preprocessing package used in the original code.
from tensorflow.keras.preprocessing.image import ImageDataGenerator
```
Then create the data generators, define the convolutional network, train it, and save the model:

```python
TRAINING_DIR = "D:/mldownload/rps/"  # must match DOWNLOAD_DIR
training_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest",
)

VALIDATION_DIR = "D:/mldownload/rps-test-set/"
validation_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = training_datagen.flow_from_directory(
    TRAINING_DIR,
    target_size=(150, 150),
    class_mode="categorical",
    batch_size=126,
)

validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_DIR,
    target_size=(150, 150),
    class_mode="categorical",
    batch_size=126,
)

model = tf.keras.models.Sequential([
    # Note: the input shape is the desired image size, 150x150 with 3 color channels
    # This is the first convolution
    tf.keras.layers.Conv2D(64, (3, 3), activation="relu", input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The second convolution
    tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The third convolution
    tf.keras.layers.Conv2D(128, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    # The fourth convolution
    tf.keras.layers.Conv2D(128, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    # Flatten the results to feed into a DNN
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.5),
    # 512-neuron hidden layer
    tf.keras.layers.Dense(512, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])

model.summary()

model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["accuracy"])

# 2520 training images / batch size 126 = 20 steps per epoch;
# 372 validation images / 126 is about 2.95, so 3 validation steps cover the test set.
history = model.fit(
    train_generator,
    epochs=25,
    steps_per_epoch=20,
    validation_data=validation_generator,
    verbose=1,
    validation_steps=3,
)

model.save("rps.h5")
```

```
Found 2520 images belonging to 3 classes.
Found 372 images belonging to 3 classes.
```
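```
Model: "sequential"
_________________________________________________________________
 Layer (type)                    Output Shape              Param #
=================================================================
 conv2d (Conv2D)                 (None, 148, 148, 64)      1792
 max_pooling2d (MaxPooling2D)    (None, 74, 74, 64)        0
 conv2d_1 (Conv2D)               (None, 72, 72, 64)        36928
 max_pooling2d_1 (MaxPooling2D)  (None, 36, 36, 64)        0
 conv2d_2 (Conv2D)               (None, 34, 34, 128)       73856
 max_pooling2d_2 (MaxPooling2D)  (None, 17, 17, 128)       0
 conv2d_3 (Conv2D)               (None, 15, 15, 128)       147584
 max_pooling2d_3 (MaxPooling2D)  (None, 7, 7, 128)         0
 flatten (Flatten)               (None, 6272)              0
 dropout (Dropout)               (None, 6272)              0
 dense (Dense)                   (None, 512)               3211776
 dense_1 (Dense)                 (None, 3)                 1539
=================================================================
Total params: 3,473,475
Trainable params: 3,473,475
Non-trainable params: 0
_________________________________________________________________
Epoch 1/25
20/20 [==============================] - 53s 3s/step - loss: 1.4751 - accuracy: 0.3599 - val_loss: 1.1369 - val_accuracy: 0.3333
Epoch 2/25
20/20 [==============================] - 49s 2s/step - loss: 1.1303 - accuracy: 0.3603 - val_loss: 1.0965 - val_accuracy: 0.5108
Epoch 3/25
20/20 [==============================] - 47s 2s/step - loss: 1.0863 - accuracy: 0.4032 - val_loss: 0.9790 - val_accuracy: 0.3978
Epoch 4/25
20/20 [==============================] - 50s 2s/step - loss: 1.0418 - accuracy: 0.5139 - val_loss: 0.8253 - val_accuracy: 0.7258
Epoch 5/25
20/20 [==============================] - 50s 2s/step - loss: 0.8743 - accuracy: 0.6087 - val_loss: 0.4759 - val_accuracy: 0.9651
Epoch 6/25
20/20 [==============================] - 48s 2s/step - loss: 0.8080 - accuracy: 0.6345 - val_loss: 0.6926 - val_accuracy: 0.6183
Epoch 7/25
20/20 [==============================] - 46s 2s/step - loss: 0.6538 - accuracy: 0.7103 - val_loss: 0.2193 - val_accuracy: 0.9785
Epoch 8/25
20/20 [==============================] - 46s 2s/step - loss: 0.5827 - accuracy: 0.7579 - val_loss: 0.2920 - val_accuracy: 0.9731
Epoch 9/25
20/20 [==============================] - 45s 2s/step - loss: 0.4396 - accuracy: 0.8286 - val_loss: 0.0803 - val_accuracy: 1.0000
Epoch 10/25
20/20 [==============================] - 47s 2s/step - loss: 0.3461 - accuracy: 0.8560 - val_loss: 0.3216 - val_accuracy: 0.7634
Epoch 11/25
20/20 [==============================] - 45s 2s/step - loss: 0.3198 - accuracy: 0.8730 - val_loss: 0.0706 - val_accuracy: 0.9651
Epoch 12/25
20/20 [==============================] - 45s 2s/step - loss: 0.2977 - accuracy: 0.8861 - val_loss: 0.0884 - val_accuracy: 0.9651
Epoch 13/25
20/20 [==============================] - 47s 2s/step - loss: 0.2832 - accuracy: 0.8952 - val_loss: 0.0391 - val_accuracy: 0.9839
Epoch 14/25
20/20 [==============================] - 43s 2s/step - loss: 0.1713 - accuracy: 0.9353 - val_loss: 0.0592 - val_accuracy: 0.9758
Epoch 15/25
20/20 [==============================] - 44s 2s/step - loss: 0.2972 - accuracy: 0.8913 - val_loss: 0.1070 - val_accuracy: 0.9839
Epoch 16/25
20/20 [==============================] - 48s 2s/step - loss: 0.1306 - accuracy: 0.9575 - val_loss: 0.0549 - val_accuracy: 0.9785
Epoch 17/25
20/20 [==============================] - 46s 2s/step - loss: 0.1886 - accuracy: 0.9226 - val_loss: 0.0500 - val_accuracy: 0.9866
Epoch 18/25
20/20 [==============================] - 45s 2s/step - loss: 0.1101 - accuracy: 0.9615 - val_loss: 0.0518 - val_accuracy: 0.9651
Epoch 19/25
20/20 [==============================] - 48s 2s/step - loss: 0.1343 - accuracy: 0.9556 - val_loss: 0.0105 - val_accuracy: 1.0000
Epoch 20/25
20/20 [==============================] - 45s 2s/step - loss: 0.1349 - accuracy: 0.9528 - val_loss: 0.2117 - val_accuracy: 0.8952
Epoch 21/25
20/20 [==============================] - 50s 2s/step - loss: 0.0918 - accuracy: 0.9687 - val_loss: 0.0844 - val_accuracy: 0.9651
Epoch 22/25
20/20 [==============================] - 50s 2s/step - loss: 0.2075 - accuracy: 0.9317 - val_loss: 0.0403 - val_accuracy: 0.9839
Epoch 23/25
20/20 [==============================] - 51s 3s/step - loss: 0.0937 - accuracy: 0.9675 - val_loss: 0.7548 - val_accuracy: 0.7070
Epoch 24/25
20/20 [==============================] - 46s 2s/step - loss: 0.0861 - accuracy: 0.9694 - val_loss: 0.1306 - val_accuracy: 0.9435
Epoch 25/25
20/20 [==============================] - 47s 2s/step - loss: 0.1002 - accuracy: 0.9655 - val_loss: 0.0382 - val_accuracy: 0.9839
```

ImageDataGenerator is Keras's image preprocessing class. Both generators here rescale pixel values from the range 0 to 255 down to 0 to 1. The training generator also applies random augmentation: rotation, width and height shifts, shear, zoom, and horizontal flips. Augmentation exposes the model to more variation than the raw images contain, which usually helps it generalize to new images.

Sequential defines the network as a linear stack of layers with one input and one output. Layers are stacked in order, and the output of each layer is the input of the next. Stacking several layers this way builds the network.

Convolution and pooling are two common operations in convolutional neural networks. Images may contain noise or weak signals. Convolution (Conv2D here) lets the network detect specific local image features, such as edges. Each Conv2D layer here uses 3×3 filters whose weights are learned during training. A well-known hand-designed 3×3 filter is the vertical Sobel filter, which responds to vertical edges.

Pooling (MaxPooling2D here) downsamples the data. Convolution outputs contain a lot of redundant information, so pooling shrinks its input and reduces the number of output values. MaxPooling2D(2, 2) keeps the maximum of each 2×2 window, which halves the height and width. For example, the first pooling layer reduces 148×148 to 74×74.

For details, see the Zhihu answer "How to understand convolution and pooling in convolutional neural networks (CNNs)". For more on convolution arithmetic, see the GitHub project Convolution arithmetic.

Dense is a fully connected layer. In essence, it applies a linear transformation from one feature space to another, followed by an activation function. The Dense layers take the features extracted by the convolutional layers and use a nonlinear transformation (the ReLU in the 512-unit layer) to capture relationships between them. They then map the result to the output space. The last Dense layer, with 3 units and softmax, is the output layer: it produces one probability per class.

After training, plot the training and validation results:

```python
import matplotlib.pyplot as plt

acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]

epochs = range(len(acc))

# Accuracy curves
plt.plot(epochs, acc, "r", label="Training accuracy")
plt.plot(epochs, val_acc, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend(loc=0)

# Loss curves in a separate figure
plt.figure()
plt.plot(epochs, loss, "r", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend(loc=0)

plt.show()
```

With the trained model, we can run real examples. For instance, we can upload a rock, paper, or scissors image and predict its class with model.predict. We won't go into detail here. Later, when we build an Android app with TensorFlow Lite, the phone's camera or photo gallery can supply the input images directly.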
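Inference with model.predict is mentioned above but not shown. As a rough starting point, here is a minimal sketch that does two things. It prepares one image the way the validation generator does (scale by 1/255, add a batch dimension). It then turns one row of softmax output into a label.

Several parts of this sketch are assumptions:

- The helper names `to_model_input` and `decode_prediction` and the file name `my_hand.png` are hypothetical.
- The image is assumed to be already resized to 150×150 RGB.
- The class order `paper`, `rock`, `scissors` assumes that `flow_from_directory` numbers classes by the alphanumeric order of the subdirectory names. Check `train_generator.class_indices` to confirm.

The TensorFlow calls appear only as comments.

```python
import numpy as np

# ASSUMPTION: flow_from_directory assigns class indices in alphanumeric order of
# the subdirectory names; verify with train_generator.class_indices.
CLASS_NAMES = ["paper", "rock", "scissors"]


def to_model_input(image_rgb):
    """Hypothetical helper: one 150x150 RGB image (values 0-255) -> batch (1, 150, 150, 3) in [0, 1]."""
    image = np.asarray(image_rgb, dtype=np.float32)
    if image.shape != (150, 150, 3):
        raise ValueError(f"expected a (150, 150, 3) image, got {image.shape}")
    # Same scaling as ImageDataGenerator(rescale=1./255)
    return np.expand_dims(image / 255.0, axis=0)


def decode_prediction(probabilities):
    """Hypothetical helper: one row of softmax output -> (class name, probability)."""
    probs = np.asarray(probabilities, dtype=np.float64).reshape(-1)
    index = int(np.argmax(probs))
    return CLASS_NAMES[index], float(probs[index])


# Made-up probabilities, not real model output:
print(decode_prediction([0.1, 0.7, 0.2]))  # ('rock', 0.7)

# With the trained model (needs TensorFlow and the saved rps.h5; not run here):
#   import tensorflow as tf
#   model = tf.keras.models.load_model("rps.h5")
#   img = tf.keras.utils.load_img("my_hand.png", target_size=(150, 150))
#   batch = to_model_input(tf.keras.utils.img_to_array(img))
#   print(decode_prediction(model.predict(batch)[0]))
```

Taking the argmax of the softmax output picks the most probable class. The returned probability is only a rough indication of the model's confidence.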
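To make the convolution and pooling description concrete, here is a small NumPy sketch. It applies a single-channel 3×3 "valid" convolution (one filter, no bias) followed by 2×2 max pooling, using the vertical Sobel kernel as the filter. It illustrates the operations only and is not how Keras implements them: Conv2D learns its kernels, sums over all input channels, and adds a bias. The names `conv2d_valid`, `max_pool_2x2` and `SOBEL_VERTICAL` are hypothetical.

```python
import numpy as np

# Vertical Sobel kernel: responds where intensity changes from left to right,
# i.e. along vertical edges. Conv2D would learn its own kernels instead.
SOBEL_VERTICAL = np.array([[-1, 0, 1],
                           [-2, 0, 2],
                           [-1, 0, 1]], dtype=np.float32)


def conv2d_valid(image, kernel):
    """Single-channel 'valid' convolution without kernel flip (cross-correlation, as in Conv2D)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.float32)
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def max_pool_2x2(feature_map):
    """2x2 max pooling with stride 2; an odd last row/column is dropped, like MaxPooling2D(2, 2)."""
    h, w = feature_map.shape[0] // 2, feature_map.shape[1] // 2
    trimmed = feature_map[:h * 2, :w * 2]
    # Group each 2x2 block into axes 1 and 3, then take the maximum of each block.
    return trimmed.reshape(h, 2, w, 2).max(axis=(1, 3))


# Toy 6x6 grayscale image: dark left half, bright right half (one vertical edge).
image = np.zeros((6, 6), dtype=np.float32)
image[:, 3:] = 1.0

edges = conv2d_valid(image, SOBEL_VERTICAL)
pooled = max_pool_2x2(edges)
print(edges.shape)   # (4, 4)
print(pooled.shape)  # (2, 2)
print(max_pool_2x2(np.zeros((15, 15))).shape)  # (7, 7), as in the last pooling layer
```

On the toy image, the filter responds only in the columns around the dark-to-bright boundary. Pooling keeps those strongest responses while halving each dimension. The same rules give 150 → 148 → 74 for the model's first convolution and pooling block.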
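The output shapes and parameter counts in the model.summary() output can be checked by hand using four rules:

- A Conv2D layer with F filters of size 3×3 over C input channels has (3·3·C + 1)·F parameters; the +1 is the bias.
- A "valid" 3×3 convolution shrinks each side by 2.
- 2×2 max pooling halves each side, rounding down.
- A Dense layer with N inputs and U units has (N + 1)·U parameters.

The sketch below applies these rules to the model defined earlier. `trace_model` is a hypothetical helper, and Dropout is left out because it has no parameters and does not change the shape.

```python
# Hypothetical helper: re-derive shapes and parameter counts of the rock/paper/scissors CNN.
CONV_FILTERS = [64, 64, 128, 128]  # filters in the four Conv2D layers


def trace_model(input_size=150, in_channels=3, hidden_units=512, num_classes=3):
    size, channels, rows = input_size, in_channels, []
    for filters in CONV_FILTERS:
        # (3*3*C + 1) * F: a 3x3 kernel per input channel plus one bias, for each filter
        params = (3 * 3 * channels + 1) * filters
        size = size - 3 + 1  # 'valid' 3x3 convolution: each side shrinks by 2
        rows.append(("Conv2D", (size, size, filters), params))
        size = size // 2     # MaxPooling2D(2, 2): floor(side / 2), no parameters
        rows.append(("MaxPooling2D", (size, size, filters), 0))
        channels = filters
    flat = size * size * channels
    rows.append(("Flatten", (flat,), 0))
    # Dense: (inputs + 1) * units, i.e. a weight matrix plus a bias vector
    rows.append(("Dense", (hidden_units,), (flat + 1) * hidden_units))
    rows.append(("Dense", (num_classes,), (hidden_units + 1) * num_classes))
    return rows


rows = trace_model()
for name, shape, params in rows:
    print(f"{name:<14}{str(shape):<18}{params}")
print("Total params:", sum(p for _, _, p in rows))  # Total params: 3473475
```

The first Dense layer alone accounts for 3,211,776 of the 3,473,475 parameters, because it connects every one of the 7·7·128 = 6272 flattened features to each of its 512 units.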