Template (skeleton for a PyTorch-Lightning + MONAI segmentation project):
# -*-coding:utf-8-*-
import pytorch_lightning as pl
from monai import transforms
import numpy as np
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# NOTE(review): this import line was truncated in the published source
# ("fig import KeysCollection"); restored to the full MONAI import.
from monai.config import KeysCollection
from monai.utils import set_determinism

# Fix every random seed (Python, NumPy, torch, MONAI) for reproducibility.
pl.seed_everything(42)
set_determinism(42)


class Config(object):
    """Placeholder for experiment hyper-parameters; filled in per project."""
    pass


class ObserveShape(transforms.MapTransform):
    """Debugging transform: print the shape of every array under ``keys``.

    Returns the data unchanged, so it can be inserted anywhere in a
    ``Compose`` pipeline to inspect what the previous stage emits.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super(ObserveShape, self).__init__(keys, allow_missing_keys)

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # Input arrives as (X, Y, Z) at this point in the pipeline.
            print(d[key].shape)
        return d
# Suitable for segmentation tasks whose label classes overlap
# (the output is a multi-channel mask rather than one-hot).
class ConvertLabeld(transforms.MapTransform):
    """Convert an integer label volume (0=background, 1=pancreas, 2=tumor)
    into a 2-channel float mask: channel 0 = pancreas OR tumor, channel 1 = tumor."""

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super(ConvertLabeld, self).__init__(keys, allow_missing_keys)

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            img = d[key]
            res = []
            # Merge tumor and pancreas into a single "pancreas" channel.
            res.append(np.logical_or(img == 1, img == 2))
            res.append(img == 2)  # tumor channel
            res = np.stack(res, axis=0)
            # np.float (an alias of builtin float) was removed in NumPy 1.24;
            # builtin float keeps the original float64 behavior.
            res = res.astype(float)
            d[key] = res
        return d


class LitsDataSet(pl.LightningDataModule):
    """Skeleton LightningDataModule: fill in file listing, splitting and
    transform construction in the stubbed methods below."""

    def __init__(self, cfg=None):
        super(LitsDataSet, self).__init__()
        # Avoid a mutable/stateful default argument (original: cfg=Config()).
        self.cfg = cfg if cfg is not None else Config()

    def prepare_data(self):
        # Enumerate/download raw files (runs once per node).
        self.get_init()

    # Split train/val/test and define preprocessing + augmentation.
    def setup(self, stage=None) -> None:
        self.split_dataset()
        self.get_preprocess()

    def train_dataloader(self):
        pass

    def val_dataloader(self):
        pass

    def test_dataloader(self):
        pass

    # Define the train/val transforms: loading, augmentation,
    # voxel/intensity normalization, etc.
    def get_preprocess(self):
        pass

    def get_init(self):
        pass

    def split_dataset(self):
        pass


class Lung(pl.LightningModule):
    """Skeleton LightningModule: define the network, loss, metrics and
    label post-processing in __init__, then fill in the step hooks."""

    def __init__(self, cfg=None):
        super(Lung, self).__init__()
        self.cfg = cfg if cfg is not None else Config()

    def configure_optimizers(self):
        pass

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        pass

    def training_epoch_end(self, outputs):
        pass

    def validation_epoch_end(self, outputs):
        pass

    def test_epoch_end(self, outputs):
        pass

    # Logic shared by training/validation/test *_epoch_end goes here.
    def shared_epoch_end(self, outputs, loss_key):
        pass

    # Logic shared by training/validation/test *_step goes here.
    def shared_step(self, y_hat, y):
        pass


data = LitsDataSet()
model = Lung()
early_stop = EarlyStopping()
cfg = Config()
check_point = ModelCheckpoint()
trainer = pl.Trainer(
    progress_bar_refresh_rate=10,
    gpus=1,
    # auto_select_gpus=True,  # cannot be combined with mixed-precision training
    # auto_lr_find=True,
    auto_scale_batch_size=True,
    callbacks=[early_stop, check_point],
    precision=16,  # 16 enables half-precision (mixed) training
    accumulate_grad_batches=4,
    num_sanity_val_steps=0,
    log_every_n_steps=10,
    auto_lr_find=True,
)
trainer.fit(model, data)
Segmentation example based on the MSD (Medical Segmentation Decathlon) pancreas dataset:
# -*-coding:utf-8-*-
# NOTE(review): several import lines were mangled in the published source
# ("fig import ...", "ansforms import ...", "ics import ...", "monaiworks");
# restored to the full monai.config / monai.transforms / monai.metrics /
# monai.networks paths. Duplicate DataLoader/random_split imports merged.
import os
import random

import torch
from torch import nn, functional as F, optim
from torch.utils.data import DataLoader, random_split

import monai
import pytorch_lightning as pl
import numpy as np
from glob import glob
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from monai import transforms
from monai.config import KeysCollection
from monai.data import Dataset, SmartCacheDataset, decollate_batch, list_data_collate
from monai.losses import DiceLoss, DiceFocalLoss, DiceCELoss, FocalLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.networks.utils import one_hot
from monai.transforms import (
    Activations, Activationsd, AsDiscrete, AsDiscreted, Compose, Invertd,
    LoadImaged, LoadImage, MapTransform, NormalizeIntensityd, Orientationd,
    RandFlipd, RandScaleIntensityd, RandShiftIntensityd, RandSpatialCropd,
    Spacingd, EnsureChannelFirstd, EnsureTyped, EnsureType,
    ConvertToMultiChannelBasedOnBratsClassesd, SpatialPadd,
    ScaleIntensityRangePercentilesd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld,
)

from einops import rearrange
from torchmetrics.functional import dice_score
from torchmetrics import IoU, Accuracy
from SwinUnet_3D import swinUnet_t_3D

# Fix every random seed for reproducibility.
pl.seed_everything(42)
set_determinism(42)


class Config(object):
    """Hyper-parameters for the MSD-pancreas segmentation experiment."""
    # NOTE(review): backslashes were lost when the source was published;
    # path separators restored — verify against the actual dataset location.
    data_path = r'D:\Caiyimin\Dataset\MSDPancreas'
    FinalShape = [160, 160, 160]
    # Window size for swinUnet3D; FinalShape[i] must be divisible by window_size[i].
    window_size = [5, 5, 5]
    in_channels = 1
    # At 1.0 voxel spacing the median raw volume size is (411, 411, 240);
    # the z extent ranges from 127 to 499 voxels.
    ResamplePixDim = (2.0, 2.0, 1.0)
    HuMax = 50 + 350 / 2   # upper HU clip bound (window center/width form)
    HuMin = 35 - 350 / 2   # lower HU clip bound
    low_percent = 0.5
    upper_percent = 99.5
    train_ratio, val_ratio, test_ratio = [0.8, 0.2, 0.0]
    BatchSize = 1
    NumWorkers = 0
    n_classes = 2  # two output channels: pancreas and cancer (tumor)
    lr = 3e-5  # learning rate
    back_bone_name = 'SwinUnet'
    # back_bone_name = 'Unet3D'
    # back_bone_name = 'UnetR'
    # Used for sliding-window inference in validation/test.
    roi_size = FinalShape
    slid_window_overlap = 0.5


class ObserveShape(transforms.MapTransform):
    """Debugging transform: print the shape of every array under ``keys``
    and return the data unchanged."""

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super(ObserveShape, self).__init__(keys, allow_missing_keys)

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # Input arrives as (X, Y, Z) at this point in the pipeline.
            print(d[key].shape)
        return d


class ConvertLabeld(transforms.MapTransform):
    """Convert an integer label volume (0=background, 1=pancreas, 2=tumor)
    into a 2-channel float mask: channel 0 = pancreas OR tumor, channel 1 = tumor."""

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        super(ConvertLabeld, self).__init__(keys, allow_missing_keys)

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            img = d[key]
            res = []
            # Merge tumor and pancreas into a single "pancreas" channel.
            res.append(np.logical_or(img == 1, img == 2))
            res.append(img == 2)  # tumor channel
            res = np.stack(res, axis=0)
            # np.float (alias of builtin float) was removed in NumPy 1.24;
            # builtin float keeps the original float64 behavior.
            res = res.astype(float)
            d[key] = res
        return d


class LitsDataSet(pl.LightningDataModule):
    """DataModule for MSD pancreas: enumerates NIfTI files from
    imagesTr/labelsTr/imagesTs, splits train/val/test, and builds the
    MONAI preprocessing/augmentation pipelines."""

    def __init__(self, cfg=None):
        super(LitsDataSet, self).__init__()
        # Avoid a mutable/stateful default argument (original: cfg=Config()).
        self.cfg = cfg if cfg is not None else Config()
        cfg = self.cfg
        self.data_path = cfg.data_path
        self.train_path = os.path.join(cfg.data_path, 'imagesTr')
        self.label_tr_path = os.path.join(cfg.data_path, 'labelsTr')
        self.test_path = os.path.join(cfg.data_path, 'imagesTs')
        self.train_dict = []   # [{'image': path, 'label': path}, ...]
        self.val_dict = []
        self.test_dict = []    # [{'image': path}, ...] (no labels for Ts)
        self.train_set = None
        self.val_set = None
        self.test_set = None
        self.train_process = None
        self.val_process = None

    def prepare_data(self):
        """List the raw files and build the transform pipelines."""
        train_x, train_y, test_x = self.get_init()
        for x, y in zip(train_x, train_y):
            self.train_dict.append({'image': x, 'label': y})
        for x in test_x:
            self.test_dict.append({'image': x})
        self.get_preprocess()

    # Split train/val/test and attach the transforms to MONAI Datasets.
    def setup(self, stage=None) -> None:
        self.split_dataset()
        self.train_set = Dataset(self.train_dict, transform=self.train_process)
        self.val_set = Dataset(self.val_dict, transform=self.val_process)
        # Test data goes through the deterministic (validation) pipeline.
        self.test_set = Dataset(self.test_dict, transform=self.val_process)

    def train_dataloader(self):
        cfg = self.cfg
        # list_data_collate flattens the num_samples dimension produced by
        # RandCropByPosNegLabeld into the batch dimension.
        return DataLoader(self.train_set, batch_size=cfg.BatchSize,
                          num_workers=cfg.NumWorkers,
                          collate_fn=list_data_collate)

    def val_dataloader(self):
        cfg = self.cfg
        return DataLoader(self.val_set, batch_size=cfg.BatchSize,
                          num_workers=cfg.NumWorkers)

    def test_dataloader(self):
        cfg = self.cfg
        return DataLoader(self.test_set, batch_size=cfg.BatchSize,
                          num_workers=cfg.NumWorkers)

    def get_preprocess(self):
        """Build the train (random-augmented) and val (deterministic)
        Compose pipelines: load, orient, resample, HU-window, crop, flip."""
        cfg = self.cfg
        self.train_process = Compose([
            LoadImaged(keys=['image', 'label']),
            EnsureChannelFirstd(keys=['image']),
            ConvertLabeld(keys='label'),
            Orientationd(keys=['image', 'label'], axcodes='RAS'),
            Spacingd(keys=['image', 'label'], pixdim=cfg.ResamplePixDim,
                     mode=('bilinear', 'nearest')),
            ScaleIntensityRanged(keys='image', a_min=cfg.HuMin, a_max=cfg.HuMax,
                                 b_min=0.0, b_max=1.0, clip=True),
            # CropForegroundd(keys=['image', 'label'], source_key='image'),
            SpatialPadd(keys=['image', 'label'], spatial_size=cfg.FinalShape),
            RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label',
                                   spatial_size=cfg.FinalShape,
                                   pos=1, neg=1, num_samples=1,
                                   image_key='image'),
            RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=2),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
            EnsureTyped(keys=['image', 'label']),
        ])
        self.val_process = Compose([
            LoadImaged(keys=['image', 'label']),
            EnsureChannelFirstd(keys=['image']),
            ConvertLabeld(keys='label'),
            Orientationd(keys=['image', 'label'], axcodes='RAS'),
            Spacingd(keys=['image', 'label'], pixdim=cfg.ResamplePixDim,
                     mode=('bilinear', 'nearest')),
            ScaleIntensityRanged(keys='image', a_min=cfg.HuMin, a_max=cfg.HuMax,
                                 b_min=0.0, b_max=1.0, clip=True),
            # CropForegroundd(keys=['image', 'label'], source_key='image'),
            EnsureTyped(keys=['image', 'label']),
        ])

    def get_init(self):
        """Return sorted (train images, train labels, test images) path lists."""
        # NOTE(review): the published source truncated the glob patterns to
        # '*.'; '*.nii.gz' is the MSD file format — confirm against the data.
        train_x = sorted(glob(os.path.join(self.train_path, '*.nii.gz')))
        train_y = sorted(glob(os.path.join(self.label_tr_path, '*.nii.gz')))
        test_x = sorted(glob(os.path.join(self.test_path, '*.nii.gz')))
        return train_x, train_y, test_x

    def split_dataset(self):
        """Randomly split train_dict by the configured ratios."""
        cfg = self.cfg
        num = len(self.train_dict)
        train_num = int(num * cfg.train_ratio)
        val_num = int(num * cfg.val_ratio)
        test_num = int(num * cfg.test_ratio)
        # Integer truncation can drop a few samples; give the remainder
        # to the validation split so the counts sum to num.
        if train_num + val_num + test_num != num:
            val_num += num - train_num - test_num - val_num
        self.train_dict, self.val_dict, self.test_dict = random_split(
            self.train_dict, [train_num, val_num, test_num])


class Lung(pl.LightningModule):
    """Segmentation LightningModule wrapping SwinUnet3D / UNETR / 3D UNet,
    with Dice loss, per-channel DiceMetric and sigmoid post-processing."""

    def __init__(self, cfg=None):
        super(Lung, self).__init__()
        self.cfg = cfg if cfg is not None else Config()
        cfg = self.cfg
        # NOTE(review): the published source assigned the network to `self`
        # (the blog engine stripped the ".net" substring); restored to self.net.
        if cfg.back_bone_name == 'SwinUnet':
            self.net = swinUnet_t_3D(window_size=cfg.window_size,
                                     num_classes=cfg.n_classes,
                                     in_channel=cfg.in_channels)
        else:
            from monai.networks.nets import UNETR, UNet
            if cfg.back_bone_name == 'UnetR':
                self.net = UNETR(in_channels=cfg.in_channels,
                                 out_channels=cfg.n_classes,
                                 img_size=cfg.FinalShape)
            else:
                self.net = UNet(spatial_dims=3, in_channels=1,
                                out_channels=cfg.n_classes,
                                channels=(32, 64, 128, 256, 512),
                                strides=(2, 2, 2, 2))
        # sigmoid=True because the two label channels overlap (not one-hot).
        self.loss_func = DiceLoss(smooth_nr=0, smooth_dr=1e-5,
                                  squared_pred=False, sigmoid=True)
        self.metrics = DiceMetric(include_background=True,
                                  reduction='mean_batch')
        self.post_pred = Compose([EnsureType(), Activations(sigmoid=True),
                                  AsDiscrete(threshold_values=True)])

    def forward(self, x):
        # Required: training_step calls self(x) and sliding_window_inference
        # uses predictor=self. (Lost with the ".net" stripping — see __init__.)
        return self.net(x)

    def configure_optimizers(self):
        cfg = self.cfg
        opt = optim.AdamW(params=self.parameters(), lr=cfg.lr, eps=1e-7,
                          weight_decay=1e-5)
        # lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        #     opt, T_0=5, T_mult=1)
        # return {'optimizer': opt, 'lr_scheduler': lr_scheduler,
        #         'monitor': 'valid_loss'}
        return opt

    def training_step(self, batch, batch_idx):
        x = batch['image']
        y = batch['label']
        # Training runs a plain forward on the random crop; sliding-window
        # inference is reserved for validation/test:
        # y_hat = sliding_window_inference(x, roi_size=cfg.FinalShape,
        #                                  sw_batch_size=cfg.BatchSize,
        #                                  predictor=self,
        #                                  overlap=cfg.slid_window_overlap)
        y_hat = self(x)
        loss, dice = self.shared_step(y_hat=y_hat, y=y)
        p_dice, t_dice = dice[0], dice[1]
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_pancreas_dice', p_dice, prog_bar=True)
        self.log('train_tumor_dice', t_dice, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        cfg = self.cfg
        x = batch['image']
        y = batch['label']
        y_hat = sliding_window_inference(x, roi_size=cfg.FinalShape,
                                         sw_batch_size=cfg.BatchSize,
                                         predictor=self,
                                         overlap=cfg.slid_window_overlap)
        loss, dice = self.shared_step(y_hat=y_hat, y=y)
        p_dice, t_dice = dice[0], dice[1]
        self.log('valid_loss', loss, prog_bar=True)
        self.log('valid_pancreas_dice', p_dice, prog_bar=True)
        self.log('valid_tumor_dice', t_dice, prog_bar=True)
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        cfg = self.cfg
        x = batch['image']
        y = batch['label']
        y_hat = sliding_window_inference(x, roi_size=cfg.FinalShape,
                                         sw_batch_size=1, predictor=self,
                                         overlap=cfg.slid_window_overlap)
        loss, dice = self.shared_step(y_hat=y_hat, y=y)
        p_dice, t_dice = dice[0], dice[1]
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_pancreas_dice', p_dice, prog_bar=True)
        self.log('test_tumor_dice', t_dice, prog_bar=True)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        losses, dice = self.shared_epoch_end(outputs, 'loss')
        p_dice, t_dice = dice[0], dice[1]
        self.log('train_mean_loss', losses, prog_bar=True)
        self.log('train_mean_pancreas_dice', p_dice, prog_bar=True)
        self.log('train_mean_tumor_dice', t_dice, prog_bar=True)

    def validation_epoch_end(self, outputs):
        losses, dice = self.shared_epoch_end(outputs, 'loss')
        p_dice, t_dice = dice[0], dice[1]
        self.log('valid_mean_loss', losses, prog_bar=True)
        self.log('valid_mean_pancreas_dice', p_dice, prog_bar=True)
        self.log('valid_mean_tumor_dice', t_dice, prog_bar=True)

    def test_epoch_end(self, outputs):
        losses, dice = self.shared_epoch_end(outputs, 'loss')
        p_dice, t_dice = dice[0], dice[1]
        # BUG FIX: the original logged these under 'valid_mean_*', clobbering
        # the validation metrics; test results now get their own keys.
        self.log('test_mean_loss', losses, prog_bar=True)
        self.log('test_mean_pancreas_dice', p_dice, prog_bar=True)
        self.log('test_mean_tumor_dice', t_dice, prog_bar=True)

    # Shared by training/validation/test *_epoch_end: mean loss over the
    # epoch plus the aggregated per-channel Dice (metric is reset after).
    def shared_epoch_end(self, outputs, loss_key):
        losses = []
        for output in outputs:
            # loss = output['loss'].detach().cpu().numpy()
            losses.append(output[loss_key].item())
        losses = np.mean(np.array(losses))
        dice = self.metrics.aggregate()
        self.metrics.reset()
        dice = dice.detach().cpu().numpy()
        return losses, dice

    # Shared by training/validation/test *_step: Dice loss plus the
    # per-channel Dice of the thresholded sigmoid predictions.
    def shared_step(self, y_hat, y):
        loss = self.loss_func(y_hat, y)
        y_hat = [self.post_pred(it) for it in decollate_batch(y_hat)]
        y = decollate_batch(y)
        dice = self.metrics(y_hat, y)
        # Empty-label crops can yield NaN dice/loss; zero them out.
        dice = torch.nan_to_num(dice)
        loss = torch.nan_to_num(loss)
        dice = torch.mean(dice, dim=0)
        return loss, dice


data = LitsDataSet()
model = Lung()
early_stop = EarlyStopping(
    monitor='valid_mean_loss',
    patience=10,
)
cfg = Config()
check_point = ModelCheckpoint(
    dirpath=f'./trained_models/{cfg.back_bone_name}',
    save_last=False,
    save_top_k=2,
    monitor='valid_mean_loss',
    verbose=True,
    # BUG FIX: 'valid_mean_dice' is never logged by the model (the logged
    # keys are valid_mean_pancreas_dice / valid_mean_tumor_dice), so the
    # original filename placeholder could not be resolved.
    filename='{epoch}-{valid_loss:.2f}-{valid_mean_pancreas_dice:.2f}',
)
trainer = pl.Trainer(
    progress_bar_refresh_rate=10,
    max_epochs=400,
    min_epochs=30,
    gpus=1,
    # auto_select_gpus=True,  # cannot be combined with mixed-precision training
    # auto_lr_find=True,
    auto_scale_batch_size=True,
    logger=TensorBoardLogger(save_dir=f'./logs', name=f'{cfg.back_bone_name}'),
    callbacks=[early_stop, check_point],
    precision=16,  # half-precision (mixed) training
    accumulate_grad_batches=4,
    num_sanity_val_steps=0,
    log_every_n_steps=10,
    auto_lr_find=True,
)
trainer.fit(model, data)
The SwinUnet3D source code for medical image segmentation cannot be released yet; please switch the backbone to UNETR or 3D UNet (etc.) instead.
本文发布于:2024-02-01 07:30:44,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/170674384634900.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |