医学图像处理 利用pytorch实现的可用于反传的Radon变换和逆变换

news/2024/7/21 6:36:46 标签: 图像处理, pytorch, 人工智能

医学图像处理 利用pytorch实现的可用于反传的Radon变换和逆变换

  • 前言
  • 代码
  • 实现思路
  • 实验结果

前言

Computed Tomography(CT,计算机断层成像)技术作为如今医学中重要的辅助诊断手段,也是医学图像研究的重要主题。如今,随着深度学习对CT重建、CT分割的研究逐渐深入,论文开始越来越多的利用CT的每一个环节,来扩充Feature或构造损失函数。

但是每到这个时候,一个问题就出现了,如果我要构造损失函数,我势必要保证这个运算中梯度不会断掉,否则起不到优化效果。而Radon变换目前好像没有人直接用其当作损失函数的一部分,很奇怪,故在此实现pytorch版本的Radon变换,已经验证可以反传(但是反传的对不对不敢保证,只保证能计算出反传的值)。希望能帮到需要的同学。

参考了这两篇博文,在此十分感谢这位前辈。
Python实现离散Radon变换
Python实现逆Radon变换——直接反投影和滤波反投影

代码

from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import SimpleITK as sitk


#两种滤波器的实现
def RLFilter(N, d):
    filterRL = np.zeros((N,))
    for i in range(N):
        filterRL[i] = - 1.0 / np.power((i - N / 2) * np.pi * d, 2.0)
        if np.mod(i - N / 2, 2) == 0:
            filterRL[i] = 0
    filterRL[int(N/2)] = 1 / (4 * np.power(d, 2.0))
    return filterRL

def SLFilter(N, d):
    filterSL = np.zeros((N,))
    for i in range(N):
        filterSL[i] = - 2 / (np.pi**2.0 * d**2.0 * (4 * (i - N / 2)**2.0 - 1))
    return filterSL

def IRandonTransform(image:'th.Tensor|np.array', steps:Optional[int]=None):
    '''
    Inverse Radon Transform(with Filter, FBP)

    Parameters:
        image: (B, C, W, H)
    '''
    # 定义用于存储重建后的图像的数组
    channels = image.shape[-1]
    B, C, W, H = image.shape
    if steps == None:
        steps = channels
    origin = th.zeros((B, C, steps, channels, channels), dtype=th.float32)
    filter_kernal = th.tensor(SLFilter(channels, 1)).unsqueeze(0).unsqueeze(0).float()
    Filter = nn.Conv1d(C, C, (channels), padding='same',bias=False)
    Filter.weight = nn.Parameter(filter_kernal) 

    for i in range(steps):
    	# 传入的图像中的每一列都对应于一个角度的投影值
    	# 这里用的图像是上篇博文里得到的Radon变换后的图像裁剪后得到的
        projectionValue = image[:, :, :, i]
        projectionValue = Filter(projectionValue)
        # 这里利用维度扩展和重复投影值数组来模拟反向均匀回抹过程
        projectionValueExpandDim = projectionValue.unsqueeze(2)
        projectionValueRepeat = projectionValueExpandDim.repeat((1, 1, channels, 1))
        origin[:,:, i] = rotate(projectionValueRepeat, (i / steps) * math.pi)
    #各个投影角度的投影值已经都保存在origin数组中,只需要将它们相加即可
    iradon = th.sum(origin, axis=2)
    return iradon


def rotate(image:th.Tensor, angle):
    '''
    Rotate the image in any angles(include negtive).
    angle should be pi = 180
    '''
    B= image.shape[0]
    transform_matrix = th.tensor([
            [math.cos(angle),math.sin(-angle),0],
            [math.sin(angle),math.cos(angle),0]], dtype=th.float32).unsqueeze(0).repeat(B,1,1)
    grid = F.affine_grid(transform_matrix, # 旋转变换矩阵
                            image.shape).float()	# 变换后的tensor的shape(与输入tensor相同)
    rotation = F.grid_sample(image, # 输入tensor,shape为[B,C,W,H]
                            grid, # 上一步输出的gird,shape为[B,C,W,H]
                            mode='nearest') # 一些图像填充方法,这里我用的是最近邻
    return rotation


def DiscreteRadonTransform(image:'th.Tensor|np.array', steps:Optional[int]=None):
    '''
    Radon Transform

    Parameters:
        image : (B, C, W, H)
    '''
    channels = image.shape[-1] # img_size
    B, C, W, H = image.shape
    res = th.zeros((B, channels, channels), dtype=th.float32)
    if steps == None:
        steps = channels
    for s in range(steps):
        angle = (s / steps) * math.pi
        rotation = rotate(image, -angle)
        res[:, :,s] = th.sum(rotation, dim=2)
    return res.unsqueeze(1)
    
if __name__ == '__main__':

    origin = sitk.ReadImage('/hy-tmp/data/LIDC/CT/0001.nii.gz')
    t_origin = sitk.GetArrayFromImage(origin)
    t_origin = th.tensor(t_origin)
    t_origin = t_origin[40].unsqueeze(0).unsqueeze(0)
    a = nn.Parameter(th.ones_like(t_origin))
    t_origin = t_origin * a
    ret = DiscreteRadonTransform(t_origin) # (B, 1, H, W)
    b = th.ones_like(ret)
    lf = nn.MSELoss()
    loss = lf(b, ret)
    loss.backward() # pytorch不会报错,并能返回grad
    rec = IRandonTransform(ret)
    ret = ret.squeeze(0)
    rec = rec.squeeze(0)
    plt.imsave('test.png', (t_origin.squeeze(0).squeeze(0)).data.numpy(), cmap='gray')
    plt.imsave('test2.png', (ret.squeeze(0)).data.numpy(), cmap='gray')
    plt.imsave('test3.png', (rec.squeeze(0)).data.numpy(), cmap='gray')

实现思路

这份代码实际上是参考前文提到的前辈的代码修改而来,具体而言就是把各种numpy实现的地方修改为pytorch的对应实现,其中pytorch没有对应的API来实现矩阵的Rotate,因此还参考了网上其它人实现的手写旋转的pytorch版本。并将其写作Rotate函数,在其它任务中也可以调用,这里需要注意,调用时,矩阵需要是方阵,否则会出现旋转后偏移中心的问题。

实验结果

原始图像:
请添加图片描述
Radon变换的结果:
请添加图片描述
重建结果:
请添加图片描述
这些图像可以下载下来,自己试试。


http://www.niftyadmin.cn/n/5469101.html

相关文章

【大模型】大模型 CPU 推理之 llama.cpp

【大模型】大模型 CPU 推理之 llama.cpp llama.cpp安装llama.cppMemory/Disk RequirementsQuantization测试推理下载模型测试 参考 llama.cpp 描述 The main goal of llama.cpp is to enable LLM inference with minimal setup and state-of-the-art performance on a wide var…

鸿蒙OS开发实例:【组件化模式】

组件化一直是移动端比较流行的开发方式,有着编译运行快,业务逻辑分明,任务划分清晰等优点,针对Android端的组件化;与Android端的组件化相比,HarmonyOS的组件化可以说实现起来就颇费一番周折,因为…

简单使用bootstrap-datepicker日期插件

目录 下载datepicker 方式一: 方式二: 下载依赖 下载bootstarp.js 下载jquery 使用示例 日期选择 单独选择年 单独选择月 单独选择日 设置截止日期 设置默认日期 总结 下载datepicker 方式一: 下载地址 GitHub - uxsolution…

Golang | Leetcode Golang题解之第7题整数反转

题目&#xff1a; 题解&#xff1a; func reverse(x int) (rev int) {for x ! 0 {if rev < math.MinInt32/10 || rev > math.MaxInt32/10 {return 0}digit : x % 10x / 10rev rev*10 digit}return }

学习鸿蒙基础(12)

目录 一、网络json-server配置 &#xff08;1&#xff09;然后输入&#xff1a; &#xff08;2&#xff09;显示下载成功。但是输入json-server -v的时候。报错。 &#xff08;3&#xff09;此时卸载默认的json-server &#xff08;4&#xff09;安装和nodejs匹配版本的js…

2024水会|全国水科技大会第一版日程正式公布

中华环保联合会、福州大学、上海大学在四川省成都市联合举办“2024全国水科技大会暨技术装备成果展览会”。 大会主题&#xff1a;加快形成新质生产力 增强水业发展新动能 大会亮点&#xff1a;邀请部委、四川省、各市领导&#xff0c;6位院士&#xff0c;100余位行业专家&a…

PTA题解 --- 天梯赛的赛场安排(C语言)

今天是PTA题库解法讲解的第八天&#xff0c;今天我们要讲解天梯赛的赛场安排&#xff0c;题目如下&#xff1a; 解题思路&#xff1a; 这个问题的关键在于高效地为参赛学校的队员分配赛场&#xff0c;同时满足给定的条件。我们可以通过以下步骤解决这个问题&#xff1a; 存储每…

解决Vue中仓库持久化的问题,不借助插件用原生JS实现仓库持久化。了解仓库的插件机制、监听的时机

1、演示 前言&#xff1a;目前Vue有两种仓库&#xff0c;一种是Vuex&#xff0c;一种是Pinia&#xff0c;懂得都懂&#xff0c;这里就不详细介绍这两者的区别了 2、什么是持久化 仓库里面的数据是需要跨越页面周期的&#xff0c;当页面刷新之后数据还在&#xff0c;在默认情况下…