【论文10】复现代码tips

avatar
作者
筋斗云
阅读量:0

一、准备工作

1.创建一个虚拟环境

conda create --name drgcnn38 python=3.8.18

2.激活虚拟环境

conda activate drgcnn38

注意事项

在Pycharm中终端(terminal)显示PS而不是虚拟环境base

问题如下所示

解决方法:shell路径改成cmd.exe

重启终端显示虚拟环境

3.安装torch

conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 cpuonly -c pytorch

安装一系列包

注意事项

Pycharm远程连接Linux服务器实现代码同步

1.工具-->部署-->配置

2.选择SFTP远程连接,路径填与服务器要同步的路径地址

二、代码学习

各部分的作用

  • eye_pre_process:视网膜眼底图像预处理模块。
  • Encoder:编码器训练模块。
  • modules:包含模型结构、损失函数和学习率降低策略。
  • utils:包含一些常用函数和评估指标。
  • BFFN:双眼特征融合网络训练模块。
  • CAM:类别注意力模块。

eye_pre_process

copy.py

# 创建一个ArgumentParser对象,用于处理命令行参数   parser = argparse.ArgumentParser()      # 添加一个命令行参数 '--image-folder',类型为字符串,默认值为 'D:/cv_paper/lesson/Dataset/ceshi'   # 这个参数用于指定输入图像的文件夹路径   parser.add_argument('--image-folder', type=str, default=r'D:/cv_paper/lesson/Dataset/ceshi')      # 添加一个命令行参数 '--output-folder',类型为字符串,默认值为 'D:\cv_paper\lesson/Dataset/ceshi_output'   # 注意:这里路径中的反斜杠在不同的操作系统中可能需要特别注意,Python字符串中推荐使用原始字符串(r前缀)来避免转义字符的问题   # 这个参数用于指定输出结果的文件夹路径   parser.add_argument('--output-folder', type=str, default=r'D:\cv_paper\lesson/Dataset/ceshi_output')      # 添加一个命令行参数 '--crop-size',类型为整数,默认值为512   # 这个参数用于指定图像裁剪的大小   parser.add_argument('--crop-size', type=int, default=512, help='crop size of image')      # 添加一个命令行参数 '-n' 或 '--num-processes',类型为整数,默认值为8   # 这个参数用于指定处理任务时要使用的进程数   # '-n' 是 '--num-processes' 的简写形式,帮助信息说明了该参数的作用   parser.add_argument('-n', '--num-processes', type=int, default=8, help='number of processes to use')
# 转换一个包含多个任务的列表,每个任务由文件名、目标路径和裁剪大小组成   # 对于jobs列表中的每个任务(索引为j),它首先检查是否已经处理了100个任务(作为进度指示),然后调用convert函数来执行实际的图像转换。 def convert_list(i, jobs):       for j, job in enumerate(jobs):           # 每处理100个任务打印一次进度           if j % 100 == 0:               print(f'worker{i} has finished {j} tasks.')           # 解包任务元组并调用convert函数           convert(*job)      # 转换单个图像文件,包括模糊处理、裁剪和保存   def convert(fname, tgt_path, crop_size):       img = Image.open(fname)  # 打开图像文件          blurred = img.filter(ImageFilter.BLUR)  # 应用模糊滤镜       ba = np.array(blurred)  # 将图像转换为NumPy数组       h, w, _ = ba.shape  # 获取图像的高度、宽度和通道数          # 尝试根据图像的亮度分布来识别前景区域       if w > 1.2 * h:           # 计算左右两侧的最大亮度值           left_max = ba[:, :w // 32, :].max(axis=(0, 1)).astype(int)           right_max = ba[:, -w // 32:, :].max(axis=(0, 1)).astype(int)           max_bg = np.maximum(left_max, right_max)              foreground = (ba > max_bg + 10).astype(np.uint8)  # 识别前景区域           bbox = Image.fromarray(foreground).getbbox()  # 获取前景区域的最小边界框              # 如果边界框太小或不存在,则打印消息并可能设置为None           if bbox is None:               print(f'No bounding box found for {fname} (???)')           else:               left, upper, right, lower = bbox               if right - left < 0.8 * h or lower - upper < 0.8 * h:                   print(f'Bounding box too small for {fname}')                   bbox = None       else:           bbox = None  # 如果图像已经是合适的宽高比,则不尝试识别前景          # 如果未找到有效的边界框,则使用正方形边界框       if bbox is None:           bbox = square_bbox(img)          # 使用边界框裁剪图像,并调整大小       cropped = img.crop(bbox)       cropped = cropped.resize([crop_size, crop_size], Image.ANTIALIAS)  # 注意:ANTIALIAS可能是个拼写错误,应该是ANTIALIASIS       save(cropped, tgt_path)  # 保存图像      # 返回一个正方形裁剪框的边界   def square_bbox(img):       w, h = img.size       left = max((w - h) // 2, 0)       upper = 0       right = min(w - (w - h) // 2, w)       lower = h       return (left, upper, right, lower)      # 保存PIL图像到文件   def save(img, fname):       img.save(fname, quality=100, subsampling=0)  # 注意:subsampling参数可能不是所有格式都支持      # 假设的main函数,用于组织整个流程(注意:这里只是一个示例)   def main():       # 示例任务列表,每个任务是一个(文件名, 目标路径, 裁剪大小)元组       jobs = [           ('input1.jpg', 'output1_resized.jpg', 256),           ('input2.jpg', 'output2_resized.jpg', 256),           # ... 更多任务       ]              # 假设有一个工作者ID为1       convert_list(1, jobs)      if __name__ == "__main__":       main()  

Encoder

main.py

# 定义主函数入口   def main():       # 解析配置参数       args = parse_configuration()       # 加载配置文件       cfg = load_config(args.config)       # 获取配置中保存的路径       save_path = cfg.config_base.config_save_path       # 如果保存路径不存在,则创建该路径       if not os.path.exists(save_path):           os.makedirs(save_path)       # 将配置文件复制到保存路径       copy_config(args.config, cfg.config_base.config_save_path)       # 执行工作函数       worker(cfg)      # 定义工作函数,负责训练、验证和测试模型   def worker(cfg):       # 根据配置生成模型       model = generate_model(cfg)       # 计算模型总参数数量       total_param = 0       for param in model.parameters():           total_param += param.numel()       print("Parameter: %.2fM" % (total_param / 1e6))  # 打印模型参数数量(单位:百万)       # 根据配置生成训练、验证和测试数据集       train_dataset, test_dataset, val_dataset = generate_dataset(cfg)       # 初始化性能评估器       estimator = PerformanceEvaluator(cfg.config_train.config_criterion, cfg.config_data.config_num_classes)       # 执行训练过程       train(           cfg=cfg,           model=model,           train_dataset=train_dataset,           val_dataset=val_dataset,           estimator=estimator,       )          # 测试最佳验证模型性能       print('This is the performance of the best validation model:')       checkpoint = os.path.join(cfg.config_base.config_save_path, 'best_validation_weights.pt')       cfg.config_train.config_checkpoint = checkpoint  # 设置检查点路径为最佳验证模型       model = generate_model(cfg)  # 重新生成模型以加载权重       evaluate(cfg, model, test_dataset, estimator)  # 评估模型性能          # 测试最终模型性能       print('This is the performance of the final model:')       checkpoint = os.path.join(cfg.config_base.config_save_path, 'final_weights.pt')       cfg.config_train.config_checkpoint = checkpoint  # 设置检查点路径为最终模型       model = generate_model(cfg)  # 重新生成模型以加载权重       evaluate(cfg, model, test_dataset, estimator)  # 评估模型性能      # 如果此脚本作为主程序运行,则调用main函数   if __name__ == '__main__':       main()

Encoder_predict.py

进行模型的训练,具体来说,它定义了一个训练循环&#x

广告一刻

为您即时展示最新活动产品广告消息,让您随时掌握产品活动新动态!