import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from gpt import GPTLanguageModel
from transformers import GPT2Tokenizer
from tqdm import tqdm
from utils.draw_plots import Draw
import pynvml as nvml
import os
import time
import wandb
from utils import draw_stuff
from torch.utils.data import DataLoader
from utils.load_data import WikipediaDataset
import signal
from colorama import Fore
from queue import Queue
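# Training harness for GPTLanguageModel: streams Wikipedia documents through the
# GPT-2 tokenizer, trains with AdamW and a cosine-annealed learning rate, logs to
# Weights & Biases, pauses when the GPU overheats, and writes checkpoints and plots
# at each evaluation interval.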
class Train():
    def __init__(self, **kwargs):
        torch.manual_seed(1137)
        model = GPTLanguageModel()
        nvml.nvmlInit()
        os.system("cls" if os.name == 'nt' else 'clear')
        draw_stuff.draw()
        self.enc = GPT2Tokenizer.from_pretrained('gpt2')
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.m = model.to(self.device)
        self.block_size = kwargs.get('block_size', 256)
        self.batch_size = kwargs.get('batch_size', 100)
        gpu_idx = kwargs.get('gpu_index', 0)
        self.handle = nvml.nvmlDeviceGetHandleByIndex(gpu_idx)
        self.temp_thres = kwargs.get('temp_threshold', 85)
        self.plot = Draw()
    def load_data(self, split, split_per):
        dataset = WikipediaDataset(split=split, split_per=split_per)
        data_loader = DataLoader(dataset, pin_memory=True)
        return data_loader
@torch.no_grad()
    def estimate_loss(self, eval_iters) -> dict:
out = {}
self.m.eval()
for split in ['train', 'val']:
            print(f"Estimating loss on the {split} split\n")
            es_progress = tqdm(total=eval_iters, ncols=100, colour='red')
            losses = torch.zeros(eval_iters)
            # get_batch is a generator; draw eval_iters batches from it.
            # Assumes the split yields at least eval_iters batches.
            batch_gen = self.get_batch(split)
            for k in range(eval_iters):
                es_progress.update(1)
                X, Y = next(batch_gen)
                logits, loss = self.m(X, Y)
                losses[k] = loss.item()
            es_progress.close()
            out[split] = losses.mean().item()
self.m.train()
return out
    def get_batch(self, split, curr_epoch=None):
        data_loader = self.load_data(split, split_per=0.8)
        data_len = len(data_loader)
        if split == 'train':
            print(f'loading data for training epoch: {curr_epoch}')
            load_progress = tqdm(total=data_len, ncols=100, colour='yellow')
        elif split == 'val':
            print(f'loading data for validation iteration: {curr_epoch}')
            load_progress = tqdm(total=data_len, ncols=100, colour='magenta')
        for doc in data_loader:
            load_progress.update(1)
            doc_txt = doc['text'][0]
            title = doc['title'][0]
            encoded_tok = self.enc.encode(doc_txt, add_special_tokens=True)
            doc_txt_enc = torch.tensor(encoded_tok, dtype=torch.long)
            size = self.block_size
            if len(encoded_tok) - self.block_size <= 0:
                if len(encoded_tok) - 20 >= 3:
                    # Document is shorter than block_size: fall back to a smaller context window.
                    size = 20
                else:
                    # Too short even for the reduced window; skip this document.
                    continue
            # Sample batch_size random offsets and build next-token prediction pairs:
            # the targets y are the inputs x shifted one position to the right.
            ix = torch.randint(len(encoded_tok) - size, (self.batch_size,))
            x = torch.stack([doc_txt_enc[i:i + size] for i in ix])
            y = torch.stack([doc_txt_enc[i + 1:i + size + 1] for i in ix])
            load_progress.set_description("Data Loader Step")
            yield x.to(self.device), y.to(self.device)
        load_progress.close()
def display_train_params(self,train_id, device, **kwargs):
print(f"**NOTE: GPU temperature threshold has been set to {self.temp_thres}°C, when the threshold is reached the training process will halt for a set period.**\n")
print(f"INFO {torch.cuda.memory_allocated(device)} Bytes of memory is allocated for this task\n")
print(f"""Training parameters:
Device: {nvml.nvmlDeviceGetName(self.handle)}
Halt Temperature threshold:{self.temp_thres}
Trainable Parameters: {sum(p.numel() for p in self.m.parameters())/1e6, 'M parameters'}
Total Epochs: {kwargs['epochs']}
Evaluation Iterations: {kwargs['eval_iters']}
Evaluation Interval: {kwargs['eval_interval']}
Initial Learning rate: {kwargs['learning_rate']}
Learning Rate Schduler: Cosine Annealing
Total Memory allocated: {torch.cuda.memory_allocated(device)}B
\n""")
print(f"** Training Started | Train ID : {train_id}**\n")
def train(self, device, train_id, is_save=True, **kwargs):
        # Resolve training hyperparameters from kwargs, falling back to defaults.
        eval_interval = kwargs.get('ei', 500)
        learning_rate = kwargs.get('learning_rate', 3e-4)
        eval_iters = kwargs.get('eval_iter', 300)
        max_iters = kwargs.get('epochs', 10000)
        wandb.init(project='gpt-model',
                   name='Training',
                   config={
                       "initial_learning_rate": learning_rate,
                       "architecture": "transformer",
                       "dataset": "Wikipedia general documents"
                   }
                   )
        # Create the per-run results directory tree (checkpoints and their plots).
        os.makedirs(f'results/{train_id}/checkpoints/plots', exist_ok=True)
torch.cuda.empty_cache()
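        # AdamW optimizer with a cosine-annealing learning-rate schedule spanning the full run.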
optimizer = torch.optim.AdamW(self.m.parameters(),
lr=learning_rate)
        scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=max_iters)
self.display_train_params(train_id=train_id,
device=device,
epochs=max_iters,
eval_interval=eval_interval,
eval_iters=eval_iters,
learning_rate=learning_rate
)
        epoch_progress_bar = tqdm(total=eval_interval, ncols=100, colour='cyan')
counter= 0
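        # cont_params collects per-step metrics between checkpoints (cleared after each plot);
        # ckpt_params accumulates checkpoint-level metrics for the final summary plots.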
cont_params={'train_loss':[],
'l_r':[]}
ckpt_params={'train_loss':[],
'val_loss':[],
'l_r':[]}
it_cnt=0
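        # Each outer iteration is treated as one epoch: every eval_interval epochs the model
        # is evaluated and checkpointed, then one pass is made over the training documents.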
for iter in range(0, max_iters):
epoch_progress_bar.update(1)
curr_temperature= nvml.nvmlDeviceGetTemperature(self.handle,
nvml.NVML_TEMPERATURE_GPU
)
if curr_temperature >= self.temp_thres:
print(f"\n Set temperature threshold of {self.temp_thres}°C reached halting for {4} seconds ")
time.sleep(3)
print("\n Resuming Training ")
if iter % eval_interval == 0 or iter == max_iters-1:
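                # Periodic evaluation: estimate train/val losses, log them, plot the window
                # since the previous checkpoint, and save the model weights.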
checkpoint_save_path= f'results/{train_id}/checkpoints/checkpoint-{counter}-epoch{iter}.pth'
losses =self.estimate_loss(eval_iters)
                ckpt_params['l_r'].append(scheduler.get_last_lr()[0])
ckpt_params['train_loss'].append(losses['train'])
ckpt_params['val_loss'].append(losses['val'])
wandb.log({"eval epoch":iter, "validation loss":losses['val']})
plot_save_path= f'results/{train_id}/checkpoints/plots/checkpoint-{counter}'
                if iter != 0:
self.plot.draw_line(mode='loss',
train_loss= cont_params['train_loss'],
epochs=[it for it in range(it_cnt, it_cnt+eval_interval)],
save_path=f'{plot_save_path}-epoch_vs_loss.png'
)
self.plot.draw_line(mode='lrvloss',
l_r= cont_params['l_r'],
train_loss=cont_params['train_loss'],
save_path=f'{plot_save_path}-learning_rate_vs_loss.png'
)
self.plot.draw_line(mode='lrve',
l_r= cont_params['l_r'],
epochs=[it for it in range(it_cnt, it_cnt+eval_interval)],
save_path=f'{plot_save_path}-learning_rate_vs_epoch.png'
)
cont_params['l_r'].clear()
cont_params['train_loss'].clear()
torch.save(self.m.state_dict(), checkpoint_save_path)
                it_cnt = iter  # start of the next plotting window
print(f"step {iter+1}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
epoch_progress_bar.close()
                epoch_progress_bar = tqdm(total=eval_interval, ncols=100, colour='cyan')
counter+=1
            print(f'loading data for epoch {iter}')
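            # One pass over the dataset: get_batch yields one (input, target) batch per document.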
            for xb, yb in self.get_batch(split='train', curr_epoch=iter):
logits, loss = self.m(xb, yb)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
                scheduler.step()
                curr_lr = scheduler.get_last_lr()[0]
train_loss= loss.item()
cont_params['l_r'].append(curr_lr)
cont_params['train_loss'].append(train_loss)
                wandb.log({'epoch': iter, 'train loss': train_loss, 'learning rate': curr_lr})
epoch_progress_bar.set_description(f"Epoch: {iter}/{max_iters} |current LR- {curr_lr}")
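        # Training complete: draw summary plots from the checkpoint-level metrics.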
        os.makedirs(f'results/{train_id}/final_plots', exist_ok=True)
self.plot.draw_line(mode='loss',
train_loss= ckpt_params['train_loss'],
val_loss=ckpt_params['val_loss'],
                            epochs=[it for it in range(0, max_iters, eval_interval)],
save_path=f'results/{train_id}/final_plots/plot-epoch_vs_loss.png',
plot_type='ckpt'
)
self.plot.draw_line(mode='lrvloss',
l_r= ckpt_params['l_r'],
train_loss=ckpt_params['train_loss'],
val_loss= ckpt_params['val_loss'],
save_path=f'results/{train_id}/final_plots/plot-learning_rate_vs_loss.png',
plot_type='ckpt'
)
self.plot.draw_line(mode='lrve',
l_r= ckpt_params['l_r'],
                            epochs=[it for it in range(0, max_iters, eval_interval)],
                            save_path=f'results/{train_id}/final_plots/plot-learning_rate_vs_epoch.png',
plot_type='ckpt'
)
nvml.nvmlShutdown()
wandb.finish()