自定义数据集

自定义数据集至少要重写__init____len____getitem__三个方法,init中定义数据路径,最好能把数据读进内存;len中定义有多少个训练样本;getitem尽量只从内存读,避免读磁盘,若数据太大,可以维持一个固定大小的内存池,偶尔从磁盘读。

若getitem包含运算,则设置num_workers>0,并行读取
torch.backends.cudnn.benchmark = True 开启可以加速卷积神经网络运算。

Dataloader示例:

from torch.utils.data import Dataloader
from tqdm import tqdm

dataloader = Dataloader(dataset, batchsize=8, shuffle=True)

for i in range(epoch):
    with tqdm(total=len(dataloader)) as t:
        for idx, (batch_x, batch_y) in enumerate(dataloader):
            # pre_y = model(batch_x)
            # loss= loss_fn(pre_y, batch_y)
            t.set_description(desc="Epoch %i:"%i)
            t.set_postfix(steps=idx, loss=loss.data.item())
            t.update(1)
            # optimizer.zero_grad()
            # loss.backward()
            # optimizer.step()

其他注意

  1. 代码文件中标注

    __author__ = 'kly'
  2. 配置参数使用argparse,并在运行时打印配置信息,以便日志中保存:

    import argparse
        def make_parser():
            parser = argparse.ArgumentParser("train parameter")
            parser.add_argument("-w", "--weight_path", default="./weights", type=str, help="model save path")
            parser.add_argument("--log_path", default="./logs", type=str, help="tensorboard log save path")
            parser.add_argument("--class_nums", default=45, type=int, help="how many classes do you have")
            parser.add_argument("--epoch_nums", default=5, type=int, help="train epoch")
            parser.add_argument("--batch_size", default=16, type=int, help="batch size")
            parser.add_argument("--lr", default=0.0001, type=float, help="init lr")
            parser.add_argument("--tsize", default=256, type=int, help="train img size = tsize * tsize")
            return parser
    
    # main函数中
    args = make_parser().parse_args()
    print("----------------------------------------------------------------")
    for key in args.__dict__:
     print(key, end=' = ')
     print(args.__dict__[key])
    print("----------------------------------------------------------------")
  3. 对于有后续完善空间的部分要标注#TODO

  4. 适当写警告和报错语句

    try:
        assert np.isfinite(score)
    except AssertionError as e:
        raise ValueError('score is NaN or infinite') from e
    
    raise ValueError('the function is not supported now')
  5. 固定所有随机数种子。

  6. 每次训练要将训练log输出到文件,保存在train.log中,并且log中需打印出本次实验的参数配置。

    import sys
    import os
    
    class Logger(object):
        def __init__(self, filename='default.log', stream=sys.stdout):
            self.terminal = stream
            self.log = open(filename, 'w')
    
        def write(self, message):
            self.terminal.write(message)
            self.log.write(message)
    
        def flush(self):
            pass
    
    sys.stdout = Logger(os.path.join(args.log_path,time.strftime('train21-%m-%d-%H%M.log',time.localtime(time.time()))), sys.stdout)
    # sys.stderr = Logger('a.log_file', sys.stderr)
  7. 打印训练过程时,记得加上时间

    import time
    print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
  8. 使用Tensorboard进行可视化

    pytorch使用Tensorboard示例

    from torch.utils import tensorboard
    
    writer = tensorboard.SummaryWriter('../logs/')
    print('tensorboard initialized')
    
    init_image = torch.zeros((1,3,224,224), device=device)
    writer.add_graph(model, init_image)
    
    writer.add_scalar('train_loss', loss, epoch+1)
    writer.add_scalar('val_acc', acc, epoch+1)
    writer.add_scalar('lr', lr, epoch+1)
  9. 一些需要反复使用的调试语句,可以使用logger输出

    给logger设置是告诉它要记录哪些级别的日志,给handler设是告诉它要输出哪些级别的日志,相当于进行了两次过滤。这样的好处在于,当我们有多个日志去向时,比如既保存到文件,又输出到控制台,就可以分别给他们设置不同的级别;logger 的级别是先过滤的,所以被 logger 过滤的日志 handler 也是无法记录的,这样就可以只改 logger 的级别而影响所有输出。两者结合可以更方便地管理日志记录的级别。

    logging.FileHandler -> 文件输出

    logging.StreamHandler() # 控制台输出

    logging.handlers.RotatingFileHandler -> 按照大小自动分割日志文件,一旦达到指定的大小重新生成文件

    logging.handlers.TimedRotatingFileHandler -> 按照时间自动分割日志文件

    logger.debug('debug级别,一般用来打印一些调试信息,级别最低')

    logger.info('info级别,一般用来打印一些正常的操作信息')

    logger.warning('waring级别,一般用来打印警告信息')

    logger.error('error级别,一般用来打印一些错误信息')

    logger.critical('critical级别,一般用来打印一些致命的错误信息,等级最高')

    import logging
    from logging import handlers
    
    logger = logging.getLogger('train')
    logger.setLevel(level=logging.DEBUG)  # 设置打印级别
    formatter = logging.Formatter('%(asctime)s: %(message)s')  # 设置打印格式
    
    stream_handler = logging.StreamHandler()  # 控制台输出
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    
    file_handler = logging.FileHandler('train1.log', encoding='utf-8')
    file_handler.setLevel(level=logging.INFO)
    file_handler.setFormatter(formatter)
    
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    
    logger.info('info级别,一般用来打印一些正常的操作信息')
    
    time_rotating_file_handler = handlers.TimedRotatingFileHandler(filename='rotating_test.log', when='D',encoding='utf-8')
    time_rotating_file_handler.setLevel(logging.INFO)
    time_rotating_file_handler.setFormatter(formatter)
    logger.addHandler(time_rotating_file_handler)
  10. 代码需要注意包含:断点续训、保存模型、加载模型进行测试这几部分。

     # pytorch
    
     # 加载模型
     net = resnet34()
     # load pretrain weights
     # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
     model_weight_path = "./resnet34-pre.pth"
     assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
     net.load_state_dict(torch.load(model_weight_path, map_location=device))
    
     # 保存最优模型
     if val_accurate > best_acc:
         best_acc = val_accurate
         torch.save(net.state_dict(), save_path)
  11. 项目开源时要有requirements.txt文件,用于记录所有依赖包及其精确的版本号。

    主要的用法如下

     pip freeze > requirements.txt  # 生成requirements.txt
     pip install -r requirements.txt # 从requirements.txt安装依赖

    文件中支持的写法

     -r base.txt  # base.txt下面的所有包
     pypinyin==0.12.0 # 指定版本(最日常的写法)
     django-querycount>=0.5.0 # 大于某个版本
     django-debug-toolbar>=1.3.1,<=1.3.3 # 版本范围
     ipython # 默认(存在不替换,不存在安装最新版)

标签: python

添加新评论