Tensorflow2.0笔记 - ResNet实践

        本笔记记录使用ResNet18网络结构,进行CIFAR100数据集的训练和验证。由于参数较多,训练时间会比较长,因此只跑了10个epoch,准确率还没有提升上去。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__


#关于ResNet的描述,可以参考如下链接:
#https://blog.csdn.net/qq_39770163/article/details/126169080
#代码基于ResNet18结构,有少许不一样
class BasicBlock(layers.Layer):
    def __init__(self, filter_num, strides = 1):
        super(BasicBlock, self).__init__()
        #卷积层1
        self.conv1 = layers.Conv2D(filter_num, (3,3), strides = strides, padding='same')
        #BN层
        self.bn1 = layers.BatchNormalization()
        #Relu层
        self.relu = layers.Activation('relu')

        #卷积层2,BN层2,
        self.conv2 = layers.Conv2D(filter_num, (3,3), strides = 1, padding='same')
        self.bn2 = layers.BatchNormalization()

        #Shortcut
        if strides != 1:
            #如果strides不为1,需要下采样
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, (1,1), strides=strides))
        else:
            #strides为1, 直接返回原始值即可
            self.downsample = lambda x:x
        
    def call(self, inputs, training = None):
        #经过第一个卷积层,BN和Relu
        out = self.conv1(inputs)
        out = self.bn1(out)
        out = self.relu(out)

        #经过第二个卷积层
        out = self.conv2(out)
        out = self.bn2(out)

        #Shortt处理,out和输入相加
        identity = self.downsample(inputs)
        output = layers.add([out, identity])
        #再经过一个relu
        output = tf.nn.relu(output)
        return output

class ResNet(keras.Model):
    #layer_dims表示对应位置的ResBlock包含了几个BasicBlock
    #比如[2,2,2,2] => 总共4个ResBlock,每个ResBlock包含两个BasicBlock
    #num_classes表示输出的类别的个数
    def __init__(self, layer_dims, num_classes=100):
        super(ResNet, self).__init__()
        #预处理单元
        self.stem = Sequential([layers.Conv2D(64, (3,3), strides=(1,1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2,2), strides=(1,1), padding='same')
                               ])
        #创建中间ResBlock层
        self.layer1 = self.buildResBlock(64, layer_dims[0])
        self.layer2 = self.buildResBlock(128, layer_dims[1], strides=2)
        self.layer3 = self.buildResBlock(256, layer_dims[2], strides=2)
        self.layer4 = self.buildResBlock(512, layer_dims[3], strides=2)

        #自适应输出层
        self.avgpool = layers.GlobalAveragePooling2D()
        #全连接层
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training = None):
        x = self.stem(inputs)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        #经过avgpool => [b, 512]
        x = self.avgpool(x)
        #经过Dense => [b, 100]
        x = self.fc(x)
        return x

    def buildResBlock(self, filter_num, blocks, strides = 1):
        resBlocks = Sequential()
        resBlocks.add(BasicBlock(filter_num, strides))
        #后续的resBlock的strides都设置为1
        for _ in range(1, blocks):
            resBlocks.add(BasicBlock(filter_num))
        return resBlocks;

def ResNet18():
    return ResNet([2, 2, 2 ,2]);

def ResNet34():
    return ResNet([3, 4, 6, 3])


#加载CIFAR100数据集
#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x,y

y_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)

batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)

sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, 
         tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))



def main():
    #创建ResNet
    resNet = ResNet18()
    resNet.build(input_shape=[None, 32, 32, 3])
    resNet.summary()
    
    #设置优化器
    optimizer = optimizers.Adam(learning_rate=1e-3)
    #进行训练
    num_epoches = 10
    for epoch in range(num_epoches):
        for step, (x,y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                #[b, 32, 32, 3] => [b, 100]
                logits = resNet(x)
                #标签做one_hot encoding
                y_onehot = tf.one_hot(y, depth=100)
                #计算损失
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
            #计算梯度
            grads = tape.gradient(loss, resNet.trainable_variables)
            #更新参数
            optimizer.apply_gradients(zip(grads, resNet.trainable_variables))

            if (step % 100 == 0):
                print("Epoch[", epoch + 1, "/", num_epoches, "]: step - ", step, " loss:", float(loss))
        #进行验证
        total_samples = 0
        total_correct = 0
        for x,y in test_db:
            logits = resNet(x)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)

            total_samples += x.shape[0]
            total_correct += int(correct)

        #统计准确率
        acc = total_correct / total_samples
        print("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)

if __name__ == '__main__':
    main()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/588689.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据库和缓存一致性问题

hello,各位小伙伴们大家好,我是颜书凌,下面给大家讲解一下数据库和缓存的一致性问题,话不多说 1、一致性介绍 一致性就是数据保持一致,在分布式系统中,可以理解为多个节点中数据的值是一致的。 强一致性…

2024年【G3锅炉水处理】试题及解析及G3锅炉水处理模拟考试题

题库来源:安全生产模拟考试一点通公众号小程序 2024年G3锅炉水处理试题及解析为正在备考G3锅炉水处理操作证的学员准备的理论考试专题,每个月更新的G3锅炉水处理模拟考试题祝您顺利通过G3锅炉水处理考试。 1、【多选题】在可逆反应中,下面哪…

Node.js -- express 框架

文章目录 1. express 使用2. 路由2.1 路由的使用2.2 获取请求报文参数2.3 获取路由参数2.4 路由参数练习 3. express 响应设置4. 中间件4.1 全局中间件4.2 路由中间件4.3 静态资源中间件 5. 获取请求体数据 body-parser6. 防盗链7. 路由模块化8. 模板引擎8.1 了解EJS8.2 列表渲…

面试二十四、继承多态

一、继承的本质和原理 组合(Composition): 组合是一种"有一个"的关系,表示一个类包含另一个类的对象作为其成员。这意味着一个类的对象包含另一个类的对象作为其一部分。组合关系通常表示强关联,被包含的对象…

【Week-Y7】使用自己的数据集训练YOLO-v8

文章目录 一、官方环境配置与测试1. 配置环境2. 用官方图片测试(图片下载失败)3. 用本地图片测试,检查配置的环境是否可用 二、使用自己的数据集进行训练测试1. 执行split_train_val.py文件2. 执行python .\voc_label.py文件3. 创建fruit.yam…

[Python基础知识]05函数和模块

一、函数的定义 格式:def 函数名(参数列表): 注: 函数代码块以 def 关键词开头,后接函数标识符名称和圆括号()。即使该函数不需要接收任何参数,也必须保留一对空的圆括号 函数形参不需要声明其类型&#x…

layui中禁用div标签等操作

为了实现点击表格行后触发事件 然后去触发后进行操作 页面流程操作设置规定 不可编辑直接添加属性 class"layui-disabled"如果在最大的 div 设置不可编辑 但是内部有些还是可以触发使用的 所以就重写一下 取到当前 div 下的 所有的子元素 然后在给所有的子元素…

闲话 ASP.NET Core 数据校验(二):FluentValidation 基本用法

前言 除了使用 ASP.NET Core 内置框架来校验数据,事实上,通过很多第三方框架校验数据,更具优势。 比如 FluentValidation,FluentValidation 是第三方的数据校验框架,具有许多优势,是开发人员首选的数据校验…

抢先体验:MacOS成功安装PHP8.4教程

根据官方消息,PHP 8.4将于2024年11月21日发布。它将通过三个 alpha 版本、三个 beta 版本和六个候选版本进行测试。 这次的重大更新将为PHP带来许多优化和强大的功能。我们很高兴能够引导您完成最有趣的更新升级,这些更改将使我们能够编写更好的代码并构…

解决React报错Encountered two children with the same key

当我们从map()方法返回的两个或两个以上的元素具有相同的key属性时,会产生"Encountered two children with the same key"错误。为了解决该错误,为每个元素的key属性提供独一无二的值,或者使用索引参数。 这里有个例子来展示错误是…

YOLOv8主要命令讲解

YOLOv8主要有三个常用命令,分别是:train(训练)、predict(预测)、export(转化模型格式),下面我将展开讲讲三个常用命令的常用参数与具体使用方法。 一、训练 通过自己标…

STM32单片机通过串口控制DDSM210 直驱伺服电机

1 电机介绍 官方资料:https://www.waveshare.net/wiki/DDSM210 DDSM210 直驱伺服电机是基于一体化开发理念,集外转子无刷电机、编码器、伺服驱动于一体的高可靠性永磁同步电动机,其结构紧凑,安装方便,运行稳定&#x…

react核心知识

1. 对 React 的理解、特性 React 是靠数据驱动视图改变的一种框架,它的核心驱动方法就是用其提供的 setState 方法设置 state 中的数据从而驱动存放在内存中的虚拟 DOM 树的更新 更新方法就是通过 React 的 Diff 算法比较旧虚拟 DOM 树和新虚拟 DOM 树之间的 Chan…

【PCL】教程 supervoxel_clustering执行超体聚类并可视化点云数据及其聚类结果

[done, 417.125 ms : 307200 points] Available dimensions: x y z rgba 源点云milk_cartoon_all_small_clorox.pcd > Loading point cloud... > Extracting supervoxels! Found 423 supervoxels > Getting supervoxel adjacency 这段代码主要是使用PCL(Po…

Linux进程——进程的创建(fork的原理)

前言:在上一篇文章中,我们已经会使用getpid/getppid函数来查看pid和ppid,本篇文章会介绍第二种查看进程的方法,以及如何创建子进程! 本篇主要内容: 查看进程的第二种方法创建子进程系统调用函数fork 在开始前&#xff…

【华为】路由综合实验(基础)

【华为】路由综合实验 实验需求拓扑配置AR1AR2AR3AR4AR5PC1PC2 查看通信OSPF邻居OSPF路由表 BGPBGP邻居BGP 路由表 配置文档 实验需求 ① 自行规划IP地址 ② 在区域1里面 启用OSPF ③ 在区域1和区域2 启用BGP,使AR4和AR3成为eBGP,AR4和AR5成为iBGP对等体…

buuctf-misc-22.神秘龙卷风1

22.神秘龙卷风1 题目:暴力破解-翻译Brainfuck计算机语言 根据提示是4位密码,直接破解密码即可 解压后发现是这样一个文档 我们尝试使用网站翻译这个 内容由“”、“.”、“>”三种符号组成,我刚开始认为这是一种密文,经过搜索…

thinkpad电脑文件隐藏了怎么恢复?教你几招

在使用ThinkPad电脑时,有时我们可能会发现一些文件或文件夹突然“消失”了,这通常是因为它们被隐藏了。本文将为您介绍几招恢复ThinkPad电脑上隐藏文件的方法,帮助您轻松找回丢失的文件。 图片来源于网络,如有侵权请告知 一、了解…

【实时数仓架构】方法论

笔者不是专业的实时数仓架构,这是笔者从其他人经验和网上资料整理而来,仅供参考。写此文章意义,加深对实时数仓理解。 一、实时数仓架构技术演进 1.1 四种架构演进 1)离线大数据架构 一种批处理离线数据分析架构,…

when to create a ViewRootImpl

when to create a ViewRootImpl when method setView is called: when method dispatchDetachedFromWindow is called:
最新文章