我的神经网络类python代码编程习惯
自定义数据集
自定义数据集至少要重写__init__
、__len__
、__getitem__
三个方法,init中定义数据路径,最好能把数据读进内存;len中定义有多少个训练样本;getitem尽量只从内存读,避免读磁盘,若数据太大,可以维持一个固定大小的内存池,偶尔从磁盘读。
若getitem包含运算,则设置num_workers>0,并行读取
torch.backends.cudnn.benchmark = True
开启可以加速卷积神经网络运算。
Dataloader示例:
from torch.utils.data import Dataloader
from tqdm import tqdm
dataloader = Dataloader(dataset, batchsize=8, shuffle=True)
for i in range(epoch):
with tqdm(total=len(dataloader)) as t:
for idx, (batch_x, batch_y) in enumerate(dataloader):
# pre_y = model(batch_x)
# loss= loss_fn(pre_y, batch_y)
t.set_description(desc="Epoch %i:"%i)
t.set_postfix(steps=idx, loss=loss.data.item())
t.update(1)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
其他注意
-
代码文件中标注
__author__ = 'kly'
-
配置参数使用argparse,并在运行时打印配置信息,以便日志中保存:
import argparse def make_parser(): parser = argparse.ArgumentParser("train parameter") parser.add_argument("-w", "--weight_path", default="./weights", type=str, help="model save path") parser.add_argument("--log_path", default="./logs", type=str, help="tensorboard log save path") parser.add_argument("--class_nums", default=45, type=int, help="how many classes do you have") parser.add_argument("--epoch_nums", default=5, type=int, help="train epoch") parser.add_argument("--batch_size", default=16, type=int, help="batch size") parser.add_argument("--lr", default=0.0001, type=float, help="init lr") parser.add_argument("--tsize", default=256, type=int, help="train img size = tsize * tsize") return parser # main函数中 args = make_parser().parse_args() print("----------------------------------------------------------------") for key in args.__dict__: print(key, end=' = ') print(args.__dict__[key]) print("----------------------------------------------------------------")
-
对于有后续完善空间的部分要标注
#TODO
-
适当写警告和报错语句
try: assert np.isfinite(score) except AssertionError as e: raise ValueError('score is NaN or infinite') from e raise ValueError('the function is not supported now')
-
固定所有随机数种子。
-
每次训练要将训练log输出到文件,保存在
train.log
中,并且log中需打印出本次实验的参数配置。import sys import os class Logger(object): def __init__(self, filename='default.log', stream=sys.stdout): self.terminal = stream self.log = open(filename, 'w') def write(self, message): self.terminal.write(message) self.log.write(message) def flush(self): pass sys.stdout = Logger(os.path.join(args.log_path,time.strftime('train21-%m-%d-%H%M.log',time.localtime(time.time()))), sys.stdout) # sys.stderr = Logger('a.log_file', sys.stderr)
-
打印训练过程时,记得加上时间
import time print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
-
使用Tensorboard进行可视化
pytorch使用Tensorboard示例
from torch.utils import tensorboard writer = tensorboard.SummaryWriter('../logs/') print('tensorboard initialized') init_image = torch.zeros((1,3,224,224), device=device) writer.add_graph(model, init_image) writer.add_scalar('train_loss', loss, epoch+1) writer.add_scalar('val_acc', acc, epoch+1) writer.add_scalar('lr', lr, epoch+1)
-
一些需要反复使用的调试语句,可以使用logger输出
给logger设置是告诉它要记录哪些级别的日志,给handler设是告诉它要输出哪些级别的日志,相当于进行了两次过滤。这样的好处在于,当我们有多个日志去向时,比如既保存到文件,又输出到控制台,就可以分别给他们设置不同的级别;logger 的级别是先过滤的,所以被 logger 过滤的日志 handler 也是无法记录的,这样就可以只改 logger 的级别而影响所有输出。两者结合可以更方便地管理日志记录的级别。
logging.FileHandler -> 文件输出
logging.StreamHandler() # 控制台输出
logging.handlers.RotatingFileHandler -> 按照大小自动分割日志文件,一旦达到指定的大小重新生成文件
logging.handlers.TimedRotatingFileHandler -> 按照时间自动分割日志文件
logger.debug('debug级别,一般用来打印一些调试信息,级别最低')
logger.info('info级别,一般用来打印一些正常的操作信息')
logger.warning('waring级别,一般用来打印警告信息')
logger.error('error级别,一般用来打印一些错误信息')
logger.critical('critical级别,一般用来打印一些致命的错误信息,等级最高')
import logging from logging import handlers logger = logging.getLogger('train') logger.setLevel(level=logging.DEBUG) # 设置打印级别 formatter = logging.Formatter('%(asctime)s: %(message)s') # 设置打印格式 stream_handler = logging.StreamHandler() # 控制台输出 stream_handler.setLevel(logging.DEBUG) stream_handler.setFormatter(formatter) file_handler = logging.FileHandler('train1.log', encoding='utf-8') file_handler.setLevel(level=logging.INFO) file_handler.setFormatter(formatter) logger.addHandler(stream_handler) logger.addHandler(file_handler) logger.info('info级别,一般用来打印一些正常的操作信息') time_rotating_file_handler = handlers.TimedRotatingFileHandler(filename='rotating_test.log', when='D',encoding='utf-8') time_rotating_file_handler.setLevel(logging.INFO) time_rotating_file_handler.setFormatter(formatter) logger.addHandler(time_rotating_file_handler)
-
代码需要注意包含:断点续训、保存模型、加载模型进行测试这几部分。
# pytorch # 加载模型 net = resnet34() # load pretrain weights # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth model_weight_path = "./resnet34-pre.pth" assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path) net.load_state_dict(torch.load(model_weight_path, map_location=device)) # 保存最优模型 if val_accurate > best_acc: best_acc = val_accurate torch.save(net.state_dict(), save_path)
-
项目开源时要有requirements.txt文件,用于记录所有依赖包及其精确的版本号。
主要的用法如下
pip freeze > requirements.txt # 生成requirements.txt pip install -r requirements.txt # 从requirements.txt安装依赖
文件中支持的写法
-r base.txt # base.txt下面的所有包 pypinyin==0.12.0 # 指定版本(最日常的写法) django-querycount>=0.5.0 # 大于某个版本 django-debug-toolbar>=1.3.1,<=1.3.3 # 版本范围 ipython # 默认(存在不替换,不存在安装最新版)