How to use DenseNet121 in MONAI

I have CT images (512x512x84), and the labels are those same CT images segmented in ITK-SNAP, where I "painted" a specific tissue (ocular tissue). The image files are stored in train_images and the labels in Train_Labels. From these I build a list of dictionaries, create a Dataset from it, and apply the MONAI transforms (full script at the end of this post).
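
Just to show the structure, the dictionary list looks like this (a minimal sketch; the filenames here are made up, the real ones come from the glob calls in the script):

train_files = [
    {"image": "TrainData/slice_001.dcm", "label": "TrainLabel/slice_001.dcm"},  # hypothetical example paths
    {"image": "TrainData/slice_002.dcm", "label": "TrainLabel/slice_002.dcm"},
]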

The issue is that, after the transforms, each sample has size [1, 1, 512, 512, 1]. When I feed this into DenseNet121 (or any other model) I get the error that the input image size (T:128, H:128, W:1) can't be smaller than the kernel size (KT:2, KH:2, KW:2). So I tried UNet, which kind of works, except that UNet changes the last dimension of the output from [1, 1, 512, 512, 1] to [1, 1, 512, 512, 2] and I get an error; by slicing and squeezing I turn that 2 back into a 1. As a metric I use a binary confusion matrix in the model.eval() section: I keep lists where I store the matrix entries after each validation epoch, and at the end I get this:

TN metatensor(1073321133, device='cuda:0')
FP metatensor(24578667, device='cuda:0')
FN metatensor(3014885, device='cuda:0')
TP metatensor(90115, device='cuda:0')
TN+FP+FN+TP = metatensor(1101004800, device='cuda:0')

which is weird, because from what I understand the total should be 512 x 512 x 84 = 22,020,096 pixels.
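
For reference, this is the kind of quick check I'm doing to compare the shapes and the expected voxel count (a minimal sketch; train_files and train_transforms are the same objects defined in the script below):

from monai.data import DataLoader, Dataset
from monai.utils import first

# train_files / train_transforms come from the script below
check_loader = DataLoader(Dataset(data=train_files, transform=train_transforms), batch_size=1)
batch = first(check_loader)
print(batch["image"].shape, batch["label"].shape)  # both print torch.Size([1, 1, 512, 512, 1])
print(512 * 512 * 84)                               # 22020096 voxels expected for one full volume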

So I'd be very grateful if someone can help me, and I'm sorry if my explanation isn't good; I'm new to this. The MONAI version is 1.3.0 and I'm using VS Code.

import logging
import os
import sys
from glob import glob

import numpy as np
import torch
from torch import squeeze
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

from torchmetrics.classification import (
    Recall,
    Accuracy,
    BinaryRecall,
    BinaryPrecision,
    BinaryAUROC,
    BinarySpecificity,
    Dice,
    BinaryConfusionMatrix,
    BinaryF1Score,
)
from torchmetrics.functional.classification import binary_f1_score

import monai
from monai.utils import first
from monai.data import decollate_batch, DataLoader, Dataset
from monai.metrics import ROCAUCMetric
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    RandFlipd,
    RandGaussianNoised,
    RandRotate90d,
    RandZoomd,
    Resized,
    ScaleIntensityd,
    SqueezeDimd,
    ToTensord,
)
from monai.networks.nets import DenseNet121, UNet, RegUNet, VNet
from monai.networks.blocks import UnetResBlock, ResBlock

def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    data_dir = "C:/Users/ACER/Desktop/POSGRADO VIU/Prácticas/Datasets/archive/files/aneurismamonai"

    train_images = sorted(glob(os.path.join(data_dir, "TrainData", "*.dcm")))
    train_labels = sorted(glob(os.path.join(data_dir, "TrainLabel", "*.dcm")))

    val_images = sorted(glob(os.path.join(data_dir, "ValData", "*.dcm")))
    val_labels = sorted(glob(os.path.join(data_dir, "ValLabel", "*.dcm")))

    train_files = [{"image": image, "label": label} for image, label in zip(train_images, train_labels)]
    val_files = [{"image": image, "label": label} for image, label in zip(val_images, val_labels)]

    print(len(train_files))

    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys=["image"]),
            RandRotate90d(keys=["image"], prob=0.5, spatial_axes=[0, 1]),
            # RandZoomd(keys=["image", "label"], min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True),
            # Resized(keys=["image", "label"], spatial_size=(512, 512)),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys=["image", "label"]),
            # Resized(keys=["image", "label"], spatial_size=(512, 512, 2)),
            # ToTensord(keys=["image", "label"]),
        ]
    )



    train_ds = Dataset(data=train_files, transform=train_transforms)
    # batch_size is the number of images used in each iteration
    train_loader = DataLoader(train_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    first_batch = next(iter(train_loader))
    print(len(train_loader))
    dataset_size = len(train_loader.dataset)  # number of samples in the dataset
    batch_size = train_loader.batch_size      # size of each batch

    print(f"Dataset size: {dataset_size}")
    print(f"Batch size: {batch_size}")

    primer_dato = first_batch["image"]
    primer_dato = primer_dato.squeeze(-1)
    print("First sample shape:", primer_dato.shape)

    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1)




    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # in_channels depends on whether the image is grayscale or colour (1 channel here)
    # model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1, pretrained=False).to(device)
    # model = ResBlock(spatial_dims=3, in_channels=1, norm="BATCH", kernel_size=1, act=('RELU', {'inplace': True})).to(device)
    model = UNet(spatial_dims=3, in_channels=1, out_channels=1, channels=(16, 32), strides=(2, 2, 2)).to(device)

    loss_function = torch.nn.BCEWithLogitsLoss()
    # loss_function = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    # auc_metric = ROCAUCMetric()

    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    # confusion-matrix entries collected after every validation pass
    c00 = []
    c01 = []
    c10 = []
    c11 = []

    numepoch = 100


    for epoch in range(numepoch):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{numepoch}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (batch_data["image"].to(device), batch_data["label"].to(device))

            print("inputs", inputs.shape)
            optimizer.zero_grad()
            outputs = model(inputs)           # TODO: find the pixels / convert to pixels
            outputs = outputs[:, :, :, :, 1]  # UNet returns 2 in the last dim; keep one slice (this works)
            outputs = outputs.unsqueeze(4)    # restore the singleton last dimension
            print("outputs", outputs.shape)

            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            labels = labels.long()
            accuracy = Accuracy(task="binary").to(device)  # how close the values are to the target
            # f1 = BinaryF1Score().to(device)
            metricac = accuracy(outputs, labels)
            speci = BinarySpecificity().to(device)  # of the real negatives, how many are predicted as negative
            metric4 = speci(outputs, labels)

            print("Accuracy:", metricac)  # epoch 1 = 0.7937, epoch 3 = 0.7938
            print("Specificity:", metric4)

            # other metrics could be computed here: torchmetrics precision, AUC (area under the ROC curve)
            # evaluation through the loss
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            # writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)

        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            num_correct = 0.0
            metric_count = 0
            with torch.no_grad():
                for val_data in val_loader:
                    val_images, val_labels = val_data["image"].to(device), val_data["label"].to(device)

                    val_outputs = model(val_images)
                    val_outputs = val_outputs[:, :, :, :, 1]  # same last-dim fix as in training
                    val_outputs = val_outputs.unsqueeze(4)
                    # outputsval = val_outputs.data.resize_(1, 1, 128, 128, 84)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels.argmax(dim=1))
                    metric_count += len(value)
                    num_correct += value.sum().item()
                    bcm = BinaryConfusionMatrix().to(device)
                    len2 = len(val_data)
                    val_labels2 = (val_labels >= 0.5).float()
                    metric1eval = bcm(val_outputs, val_labels2)
                    print("Matrix", metric1eval)
                    c1 = metric1eval[0][0]  # true negatives
                    c2 = metric1eval[0][1]  # false positives
                    c3 = metric1eval[1][0]  # false negatives
                    c4 = metric1eval[1][1]  # true positives
                    c00.append(c1)
                    c01.append(c2)
                    c10.append(c3)
                    c11.append(c4)
                    print("length", len2)
                    print(c1, c2)
                    print(c1 + c2)
                    recall = BinaryRecall().to(device)
                    # note: these metrics are computed on the last training batch's outputs/labels
                    metricac = accuracy(outputs, labels)
                    print("Accuracy:", metricac)
                    metric2 = recall(outputs, labels)  # of the real positives, how many are predicted as positive
                    precision = BinaryPrecision().to(device)
                    metric3 = precision(outputs, labels)
                    dice = Dice().to(device)  # similarity between two sets of data
                    metric5 = dice(outputs, labels)
                    print("Recall:", metric2)
                    print("Precision:", metric3)
                    print("Dice:", metric5)

            # metric computation: the image may or may not contain eyes; confusion matrix
            metric = num_correct / metric_count
            metric_values.append(metric)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_classification3d_array.pth")
                print("saved new best metric model")

            print(f"Current epoch: {epoch + 1} current accuracy: {metric:.4f}")
            print(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
            # writer.add_scalar("val_accuracy", metric, epoch + 1)


#print("Matriz",metric1eval)
totalsum00= sum([num for num in c00])
totalsum01= sum([num2 for num2 in c01])
totalsum10= sum([num3 for num3 in c10])
totalsum11= sum([num4 for num4 in c11])
print("TN",totalsum00)
print("FP",totalsum01)
print("FN",totalsum10)
print("TP",totalsum11)
#total1=sum(totalsum00,totalsum01)
#total2=sum(totalsum10,totalsum11)
result=torch.add(totalsum11,totalsum00)
result2=torch.add(totalsum01,totalsum10)
res=torch.add(result,result2)
print(res)





print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
valx=[i + 1 for i in range (len(epoch_loss_values))]
valy= epoch_loss_values
plt.xlabel("Epoch")
plt.plot(valx,valy)
plt.show()


if name == "main":
main()
 
