How to not spend 50 hrs per epoch on an insanely big dataset

Hello guys,

I am working on a new GPT-style architecture built with PyTorch. I tested the model once and it failed miserably, mostly due to the lack of training data: I first trained it on a 400-token text document, and I am now stepping up to a 100B-token dataset. I have very limited experience working with a dataset this big, and I am working on this project by myself.

So, coming to the point: I successfully cleaned the dataset and managed to load it, and one epoch takes roughly 50 hours (technically an estimate, since I was not patient enough to sit through all 50 hours). The problem is that my current code trains the model on the entire dataset in every epoch, which is hugely inefficient, so my plan is to split the dataset into chunks of 20,000 documents and train on one chunk per epoch. That works fine until the validation part of the training loop: every 500 epochs the code computes the training and validation loss, so the function runs the model in eval mode (where torch disables the dropout layers) on both the data the model has trained on so far and on a held-out test set. The problem is that I would have to keep track of the indices of all the data the model has already passed through and load them again. I think there is a more efficient approach for this; if there is, please help me figure it out. I will attach the code that loads the data and the code for the training loop; a rough sketch of the chunking/index bookkeeping I have in mind is right below.
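
Roughly, this is the kind of per-epoch chunking and index bookkeeping I mean (just a sketch with made-up helper names, reusing the WikipediaDataset class from the loader code further down):

Code:
import torch
from torch.utils.data import DataLoader, Subset

from utils.load_data import WikipediaDataset  # my dataset class, posted below

CHUNK_SIZE = 20_000  # documents per epoch

full_train = WikipediaDataset(split='train', split_per=0.8)
all_indices = torch.randperm(len(full_train)).tolist()  # shuffle the document order once
seen_indices = []                                        # everything the model has trained on so far

def epoch_loader(epoch):
    """DataLoader over just this epoch's chunk of 20,000 documents."""
    start = (epoch * CHUNK_SIZE) % len(all_indices)
    chunk = all_indices[start:start + CHUNK_SIZE]
    seen_indices.extend(chunk)
    return DataLoader(Subset(full_train, chunk), pin_memory=True)

def seen_loader():
    """DataLoader over the data already trained on, for the 'train' side of the loss estimate."""
    return DataLoader(Subset(full_train, seen_indices), pin_memory=True)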

Thanks
:)
 
This is the code for the training loop ....

Code:
import torch 
from torch.optim.lr_scheduler import CosineAnnealingLR
from gpt import GPTLanguageModel
from transformers import GPT2Tokenizer
from tqdm import tqdm
from utils.draw_plots import Draw
import pynvml as nvml
import os
import time
import wandb
from utils import draw_stuff
from torch.utils.data import DataLoader
from utils.load_data import WikipediaDataset
import signal 
from colorama import Fore
from queue import Queue
class Train():
    def __init__(self, **kwargs):
        torch.manual_seed(1137)
        model= GPTLanguageModel()
        nvml.nvmlInit()
        os.system("cls" if os.name == 'nt' else 'clear')
        draw_stuff.draw()
        self.enc= GPT2Tokenizer.from_pretrained('gpt2')
        self.device='cuda' if torch.cuda.is_available() else 'cpu'  # is_available must be called, otherwise it is always truthy
        self.m= model.to(self.device)
        self.block_size= kwargs.get('block_size', 256)
        self.batch_size= kwargs.get('batch_size', 100)
        gpu_idx= kwargs.get('gpu_index', 0)
        self.handle= nvml.nvmlDeviceGetHandleByIndex(gpu_idx)
        self.temp_thres= kwargs.get('temp_threshold', 85)
        self.plot= Draw()
        
    def load_data(self,split, split_per):
        dataset= WikipediaDataset(split= split, split_per= split_per)
        data_title= DataLoader(dataset, pin_memory= True)
        return data_title
    @torch.no_grad()
    def estimate_loss(self, eval_iters)->torch.tensor:
        out = {}
        self.m.eval()
        for split in ['train', 'val']:
            batch_gen = self.get_batch(split)  # get_batch is a generator, so pull batches from it with next()
            es_progress= tqdm(total=eval_iters, ncols=100)
            es_progress.colour='red'
            print("Estimating loss\n")
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                es_progress.update(1)
                X, Y = next(batch_gen)
                logits, loss = self.m(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        self.m.train()
        return out
    
    def get_batch(self, split, curr_epoch=None):
        
        data = self.load_data(split, split_per=0.8)
        data_len= len(data)
        
        if split =='train':
            print(f'loading data for training epoch:{curr_epoch}')
            load_data=tqdm(total= data_len, ncols=100)
            load_data.colour= 'yellow'
        elif split == 'val':
            print(f'loading data for validation iteration:{curr_epoch}')
            load_data= tqdm(total=data_len, ncols=100)
            load_data.colour= 'magenta'
    
        for idx, doc in enumerate(data):
            load_data.update(1)
            doc_txt= doc['text'][0]
            title= doc['title'][0]
            encoded_tok= self.enc.encode(doc_txt, add_special_tokens= True)
            doc_txt_enc= torch.tensor(encoded_tok, dtype=torch.long).squeeze()
            size= self.block_size
            if len(encoded_tok) - self.block_size <= 0 and len(encoded_tok)-20 >=3:
                size=20
            if len(encoded_tok) - size <= 0:
                continue  # document too short to cut a context window from, skip it
            ix = torch.randint(len(encoded_tok) - size , (self.batch_size,))
    
            x = torch.stack([doc_txt_enc[i:i+size] for i in ix])
            y = torch.stack([doc_txt_enc[i+1:i+size+1] for i in ix])  # targets are the inputs shifted one token right
            x, y = x.to(self.device), y.to(self.device)  # move the batch onto the same device as the model
            load_data.set_description("Data Loader Step")
            yield x, y
    
    def display_train_params(self,train_id, device, **kwargs):
        print(f"**NOTE: GPU temperature threshold has been set to {self.temp_thres}°C, when the threshold is reached the training process will halt for a set period.**\n")
        print(f"INFO {torch.cuda.memory_allocated(device)} Bytes of memory is allocated for this task\n")
        print(f"""Training parameters:
                  Device: {nvml.nvmlDeviceGetName(self.handle)}
                  Halt Temperature threshold:{self.temp_thres}
                  Trainable Parameters: {sum(p.numel() for p in self.m.parameters())/1e6:.2f} M parameters
                  Total Epochs: {kwargs['epochs']}
                  Evaluation Iterations: {kwargs['eval_iters']}
                  Evaluation Interval: {kwargs['eval_interval']}
                  Initial Learning rate: {kwargs['learning_rate']}
                  Learning Rate Scheduler: Cosine Annealing
                  Total Memory allocated: {torch.cuda.memory_allocated(device)}B
                  \n""")
        
        print(f"** Training Started | Train ID : {train_id}**\n")
    
    def train(self, device, train_id, is_save=True, **kwargs):
        
        wandb.init('Training',  
                    project='gpt-model',
                    config={
                        "initial_learning_rate":3e-4,
                        "architecture":"transformer",
                        "dataset": "Wikipedia general documents"
                            }
                    )
        
        eval_interval= kwargs.get('ei', 500)
        learning_rate= kwargs.get('learning_rate', 3e-4)
        eval_iters= kwargs.get('eval_iter', 300)
        max_iters= kwargs.get('epochs', 10000)
        # makedirs with exist_ok=True creates the whole results/<train_id>/checkpoints/plots tree in one call
        os.makedirs(f'results/{train_id}/checkpoints/plots', exist_ok=True)
        torch.cuda.empty_cache()
        optimizer = torch.optim.AdamW(self.m.parameters(), 
                                      lr=learning_rate)
        schduler= CosineAnnealingLR(optimizer=optimizer, T_max=max_iters)
        self.display_train_params(train_id=train_id, 
                                  device=device, 
                                  epochs=max_iters, 
                                  eval_interval=eval_interval,
                                  eval_iters=eval_iters,
                                  learning_rate=learning_rate
                                  )
        
        epoch_progress_bar= tqdm(total=eval_interval, 
                                 ncols=100
                                )
        
        epoch_progress_bar.colour='cyan'
        counter= 0
        cont_params={'train_loss':[], 
                     'l_r':[]}
        ckpt_params={'train_loss':[], 
                     'val_loss':[],
                     'l_r':[]}
        it_cnt=0
        for iter in range(0, max_iters):
            epoch_progress_bar.update(1)
            curr_temperature= nvml.nvmlDeviceGetTemperature(self.handle, 
                                                            nvml.NVML_TEMPERATURE_GPU
                                                           )
    
            if curr_temperature >= self.temp_thres:
                print(f"\n Set temperature threshold of {self.temp_thres}°C reached halting for {4} seconds ")
                time.sleep(3)
                print("\n Resuming Training ")
            
            if iter % eval_interval == 0 or iter == max_iters-1:
                
                checkpoint_save_path= f'results/{train_id}/checkpoints/checkpoint-{counter}-epoch{iter}.pth'
                losses =self.estimate_loss(eval_iters)
                ckpt_params['l_r'].append(schduler.get_last_lr()[0])
                ckpt_params['train_loss'].append(losses['train'])
                ckpt_params['val_loss'].append(losses['val'])
                wandb.log({"eval epoch":iter, "validation loss":losses['val']})
                plot_save_path= f'results/{train_id}/checkpoints/plots/checkpoint-{counter}'
    
                if iter != 0:

                    self.plot.draw_line(mode='loss', 
                                        train_loss= cont_params['train_loss'],  
                                        epochs=[it for it in range(it_cnt, it_cnt+eval_interval)],
                                        save_path=f'{plot_save_path}-epoch_vs_loss.png'
                                        )
                    self.plot.draw_line(mode='lrvloss',
                                        l_r= cont_params['l_r'],
                                        train_loss=cont_params['train_loss'], 
                                        save_path=f'{plot_save_path}-learning_rate_vs_loss.png'
                                        )
                    
                    self.plot.draw_line(mode='lrve',
                                        l_r= cont_params['l_r'], 
                                        epochs=[it for it in range(it_cnt, it_cnt+eval_interval)],
                                        save_path=f'{plot_save_path}-learning_rate_vs_epoch.png'
                                        )
                    
                    
                    cont_params['l_r'].clear()
                    cont_params['train_loss'].clear()
                    torch.save(self.m.state_dict(), checkpoint_save_path)
                    it_cnt = iter  # next plot's x-axis should start at the current eval step, not at an accumulated sum
    
                print(f"step {iter+1}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
                
                epoch_progress_bar.close() 
                epoch_progress_bar = tqdm(total=eval_interval, ncols=100)
                epoch_progress_bar.colour='cyan'
                counter+=1
            print(f'loading data for epoch {iter}')

            for xb, yb in self.get_batch(split='train'):
                logits, loss = self.m(xb, yb)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
                schduler.step()
                curr_lr= schduler.get_last_lr()[0]
                train_loss= loss.item()
            cont_params['l_r'].append(curr_lr)
            cont_params['train_loss'].append(train_loss)
            wandb.log({'epoch':iter, 'train loss':train_loss, 'learning rate':curr_lr})            
            epoch_progress_bar.set_description(f"Epoch: {iter}/{max_iters} |current LR- {curr_lr}")
        
        os.makedirs(f'results/{train_id}/final_plots', exist_ok=True)
        self.plot.draw_line(mode='loss', 
                                    train_loss= ckpt_params['train_loss'], 
                                    val_loss=ckpt_params['val_loss'], 
                                    epochs=[it for it in range(0, max_iters, eval_interval)],
                                    save_path=f'results/{train_id}/final_plots/plot-epoch_vs_loss.png',
                                    plot_type='ckpt'
                                    )
        self.plot.draw_line(mode='lrvloss',
                            l_r= ckpt_params['l_r'],
                            train_loss=ckpt_params['train_loss'], 
                            val_loss= ckpt_params['val_loss'],
                            save_path=f'results/{train_id}/final_plots/plot-learning_rate_vs_loss.png',
                            plot_type='ckpt'
                            )
        
        self.plot.draw_line(mode='lrve',
                            l_r= ckpt_params['l_r'], 
                            epochs=[it for it in range(0, max_iters, eval_interval)],
                            save_path=f'results/{train_id}/final_plots/plot-learning_rate_vs_epoch.png',
                            plot_type='ckpt'
                            )
        nvml.nvmlShutdown()
        wandb.finish()

Some parts of the code are quite tacky; I wanted to sort out this problem before fixing the other bugs.

This is the code for loading the dataset
Code:
from torch.utils.data import Dataset
from datasets import load_dataset
from tqdm import tqdm
import sys
import re
class WikipediaDataset(Dataset):
    def __init__(self, dataset=None, split_per=0.8,split='train' , mode= 'local') -> None:
        super().__init__()
        data_code= ['wikipedia', '20220301.en']
        self.split= split
        self.split_per= split_per
        if dataset is not None:
            data_code= dataset.split('-')
            data= load_dataset(*data_code)['train']  # assumes the custom dataset also exposes a 'train' split, like the default branch below
        else :
            data= load_dataset('wikipedia', '20220301.en')['train']
        self.dataset= self.count_skip(data)
        if  mode =='local': 
            self.print_info(name= data_code[0],
                        code= data_code[-1],
                        length=len(self.dataset),
                        split_percentage=self.split_per)
        
    def __len__(self)->int:
        if self.split == 'train':
            return int(len(self.dataset)*self.split_per)
        elif self.split == 'val':
            return int(len(self.dataset)- int(len(self.dataset)*self.split_per))
    def __getitem__(self, index)->dict:
        if self.split== 'train':
            sample= self.dataset[index]
        elif self.split == 'val':
            sample= self.dataset[int(len(self.dataset)*self.split_per+index)]        
        sample_text= sample['text']
        sample_title= sample['title']
        
        return {'title':sample_title, 'text':sample_text}
    
    def clean_data(self,data)->str:
        '''
        Collapses newlines and runs of whitespace into single spaces.
        '''
        data_wn= data.replace('\n', ' ')  # replace with a space so words across line breaks don't get glued together
        data_wn_s= re.sub(r'\s+', ' ', data_wn)
        return data_wn_s
    
    def print_info(self, **kwargs):
        """
        Prints the dataset details; pass the specifications as kwargs.
        Accepted kwargs:
        name: name of the dataset
        code: dataset code
        language: language of the dataset
        type: type of the dataset, e.g. audio, text, text-audio, etc.
        length: length of the dataset
        split_percentage: train/validation split percentage
        """
        accepted_args= ['name', 'code', 'language', 'type', 'length','split_percentage']
        for arg in kwargs:
            if arg in accepted_args:
                if '_' in arg:
                    print(f"Dataset Split percentage: {kwargs[arg]*100}%")
                else:                    
                    suffix= arg[0].capitalize()+arg[1:]
                    print(f"Dataset {suffix}: {kwargs[arg]}")
        cal_len= lambda data_len, split_per, cal : data_len - data_len*split_per if cal == 'val' else data_len*split_per
        print(f"Training Data Length {cal_len(kwargs['length'],kwargs['split_percentage'],cal='train')}")
        print(f"Validation Data Length {cal_len(kwargs['length'],kwargs['split_percentage'],cal='val')}")
        cnf= input("Do you want to continue with is dataset to training [Y]es  [n]o => ") ## have to implement change dataset details 
        cnf=cnf.lower()
        if cnf == 'y' or cnf == 'yes':
            print(f'Proceeding...')
        elif cnf== 'n' or cnf == "no":
            print("Exiting...")
            sys.exit()
        else: print(f"Input '{cnf}' is invalid. Answer [Y]es or [n]o.")
        
    def count_skip(self,dataset):
        data_length= len(dataset)
        progress_bar= tqdm(total=data_length, 
                           ncols=100, 
                           dynamic_ncols=True)
        progress_bar.colour='magenta'
        data_rem= []
        counter=0
        progress_bar.set_description("Cleaning Dataset")
        for idx in range(data_length):
            progress_bar.update(1)
            # NOTE: assigning back into dataset[idx]['text'] does not persist with Hugging Face
            # datasets, so the cleaned text is only used for the length check here; persisting
            # the cleaning needs dataset.map (see the sketch after this code block)
            cleaned= self.clean_data(dataset[idx]['text'])
            if len(cleaned.split(' ')) > 20:
                data_rem.append(idx)
            else:
                counter+=1
        print(f'\n Samples removed: {counter}\n Dataset decreased by: {counter/data_length*100:.2f}%\n Current dataset length: {len(data_rem)}')
        return dataset.select(data_rem)
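
One more thing I noticed while posting this: assigning back into dataset[idx]['text'] inside count_skip does not actually persist with Hugging Face datasets, so the cleaning there is only used for the length check. I am planning to redo the cleaning with map/filter so it sticks; a rough sketch of what I mean (same clean-up logic as clean_data, not part of my code yet):

Code:
from datasets import load_dataset
import re

def clean_text(example):
    # same idea as clean_data: collapse newlines and repeated whitespace into single spaces
    example['text'] = re.sub(r'\s+', ' ', example['text'].replace('\n', ' '))
    return example

data = load_dataset('wikipedia', '20220301.en')['train']
data = data.map(clean_text)                                       # the cleaning persists here
data = data.filter(lambda ex: len(ex['text'].split(' ')) > 20)    # drop very short documents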
 
