LoRA-PEFT on MedQuad Dataset with Flan-T5, Serving with Ray & Docker

I focused on building a robust Question Answering (QA) system for the medical domain using the LoRA (Low-Rank Adaptation) PEFT (Parameter-Efficient Fine-Tuning) approach on the MedQuad dataset with the Flan-T5 model. I integrated modern machine learning and deployment tools such as Ray Serve and Docker to create an efficient and scalable system.

Here’s a detailed breakdown of everything I accomplished:

1. Data Preparation

I started by loading the MedQuad dataset, which contains medical questions and answers. I used Pandas to read the dataset from a CSV file and performed some preprocessing steps to get it ready for training:

  • I split the dataset into training, validation, and test sets using train_test_split from Scikit-Learn, ensuring a balanced distribution and avoiding any data leakage. I reset the indices to ensure clean data inputs.
  • Using Hugging Face’s datasets library, I converted these Pandas DataFrames into Dataset objects and then created a DatasetDict to manage the training, validation, and test datasets (a minimal sketch of this step follows below).
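A minimal sketch of this preparation step, assuming the CSV file name, column names, and split ratios shown here (the originals may differ):

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

# Load the MedQuad question/answer pairs (file and column names are assumptions)
df = pd.read_csv("medquad.csv")[["question", "answer"]].dropna()

# Split into train / validation / test without overlap, then reset indices
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
train_df, val_df, test_df = (
    d.reset_index(drop=True) for d in (train_df, val_df, test_df)
)

# Wrap the DataFrames as Hugging Face Datasets and group them in a DatasetDict
dataset = DatasetDict({
    "train": Dataset.from_pandas(train_df),
    "validation": Dataset.from_pandas(val_df),
    "test": Dataset.from_pandas(test_df),
})
```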

2. Model Initialization and Tokenization

I chose the Flan-T5 model, an instruction-tuned variant of the T5 transformer well suited to sequence-to-sequence tasks.

  • Loaded the tokenizer and model from Hugging Face’s model hub using T5Tokenizer and T5ForConditionalGeneration.
  • Specified a prompt prefix: “Assuming you are working as a Doctor. Please answer this question: ”, which helped guide the model to generate more domain-specific answers.
  • Defined a preprocessing function to tokenize both the input questions and target answers, handling truncation and padding to maintain consistent input sizes (sketched below).
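A sketch of this step; the checkpoint name and maximum sequence lengths are assumptions:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "google/flan-t5-base"  # assumed checkpoint; a larger variant may have been used
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

prefix = "Assuming you are working as a Doctor. Please answer this question: "

def preprocess(batch):
    # Prepend the prompt prefix to every question, then tokenize inputs and targets
    inputs = tokenizer(
        [prefix + q for q in batch["question"]],
        max_length=256, truncation=True, padding="max_length",
    )
    labels = tokenizer(
        batch["answer"], max_length=256, truncation=True, padding="max_length",
    )
    inputs["labels"] = labels["input_ids"]
    return inputs

tokenized = dataset.map(preprocess, batched=True)
```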

3. Applying LoRA (Low-Rank Adaptation)

To fine-tune the Flan-T5 model efficiently, I applied LoRA (Low-Rank Adaptation) using the PEFT framework:

  • I configured LoRA settings with parameters like rank (r=32), scaling factor (lora_alpha=16), dropout (lora_dropout=0.01), and specified that LoRA should target the query (q) and value (v) projection layers of the model. These configurations allowed me to optimize only parts of the model, making the fine-tuning more parameter-efficient.
  • I utilized the get_peft_model function to integrate LoRA into the pre-trained Flan-T5 model, resulting in a compact, fine-tuned model that performs well on domain-specific tasks without excessive computational requirements (see the configuration sketch below).
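A minimal configuration sketch using the PEFT library with the settings listed above; the task-type argument is the standard choice for T5-style sequence-to-sequence models:

```python
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32,                       # rank of the low-rank update matrices
    lora_alpha=16,              # scaling factor
    lora_dropout=0.01,
    target_modules=["q", "v"],  # T5 query and value projection layers
    task_type=TaskType.SEQ_2_SEQ_LM,
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # only a small fraction of weights are trainable
```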

4. Training the Model

I set up the training process using Hugging Face’s Seq2SeqTrainer:

  • Defined the training arguments in Seq2SeqTrainingArguments with settings such as learning rate, batch sizes, weight decay, and the number of epochs. I opted for a batch size of 8 for training and 16 for evaluation, which balanced memory usage and training efficiency.
  • Enabled logging to Weights & Biases (W&B) for tracking the training process and evaluating model performance in real-time.
  • Initiated the training using Seq2SeqTrainer, allowing the model to learn the mapping between questions and answers effectively over 15 epochs (a setup sketch follows this list).
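A sketch of the trainer setup; the output directory, learning rate, and weight decay values are assumptions, while the batch sizes, epoch count, and W&B logging follow the description above:

```python
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

training_args = Seq2SeqTrainingArguments(
    output_dir="flan-t5-medquad-lora",   # assumed output directory
    learning_rate=1e-3,                  # assumed value
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    weight_decay=0.01,                   # assumed value
    num_train_epochs=15,
    evaluation_strategy="epoch",
    report_to="wandb",                   # log metrics to Weights & Biases
)

trainer = Seq2SeqTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=peft_model),
)
trainer.train()
```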

5. Saving the Model

I saved the fine-tuned model and tokenizer locally using model.save_pretrained and tokenizer.save_pretrained. This allows easy deployment and reuse without retraining.
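In code, this amounts to the following (the directory name is an assumption):

```python
# Persist the LoRA-adapted model and tokenizer for later serving
peft_model.save_pretrained("flan-t5-medquad-lora")  # assumed directory name
tokenizer.save_pretrained("flan-t5-medquad-lora")
```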

6. Deployment with Ray Serve and Docker

To deploy the model as an interactive service, I employed Ray Serve, a scalable model serving library:

  • I wrote an API service in api_serve.py that utilizes Ray Serve to expose the model for generating answers to medical questions. The AnswerGenerator class manages loading the model, preprocessing inputs, generating answers, and sending back responses (a simplified sketch appears after this list).
  • I initialized Ray with dashboard support and started the Ray Serve server to handle incoming requests efficiently.
  • To containerize the deployment, I created a Dockerfile that sets up the environment for Ray Serve and Streamlit, exposing the necessary ports and setting up the entry point script start.sh.
  • I added start.sh to start Ray Serve, its deployments, and the Streamlit web application in sequence. This script ensures that the environment initializes correctly, preventing any race conditions during startup.
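A simplified sketch of what api_serve.py might look like; the checkpoint paths, route prefix, and generation settings are assumptions rather than the exact project code:

```python
# api_serve.py (sketch) -- serve the LoRA-adapted Flan-T5 model with Ray Serve
import ray
from ray import serve
from starlette.requests import Request
from transformers import T5Tokenizer, T5ForConditionalGeneration
from peft import PeftModel

@serve.deployment
class AnswerGenerator:
    def __init__(self):
        # Load the base model and apply the saved LoRA adapter (paths are assumptions)
        base = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
        self.model = PeftModel.from_pretrained(base, "flan-t5-medquad-lora")
        self.tokenizer = T5Tokenizer.from_pretrained("flan-t5-medquad-lora")
        self.prefix = "Assuming you are working as a Doctor. Please answer this question: "

    async def __call__(self, request: Request) -> dict:
        payload = await request.json()
        inputs = self.tokenizer(self.prefix + payload["question"], return_tensors="pt")
        output_ids = self.model.generate(**inputs, max_new_tokens=256)
        answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return {"answer": answer}

# Start Ray with the dashboard enabled and deploy the service
ray.init(include_dashboard=True)
serve.run(AnswerGenerator.bind(), route_prefix="/generate")
```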

7. Building a Streamlit Web Application

I built a Streamlit front-end application (app.py) to provide a user-friendly interface for the QA system:

  • The app allows users to input a medical question, sends a request to the Ray Serve API, and displays the generated answer in real time. I utilized requests to handle HTTP communication between Streamlit and Ray Serve.
  • Implemented input validation and error handling so that empty questions and failed requests produce clear, informative messages, enhancing the user experience (see the sketch below).
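A minimal version of app.py under these assumptions (the endpoint URL, port, and widget labels are mine, not necessarily the originals):

```python
# app.py (sketch) -- Streamlit front end for the Ray Serve QA API
import requests
import streamlit as st

st.title("Medical QA with Flan-T5 + LoRA")

question = st.text_input("Enter your medical question:")

if st.button("Get answer"):
    if not question.strip():
        st.warning("Please enter a question first.")
    else:
        try:
            # Forward the question to the Ray Serve endpoint (URL is an assumption)
            resp = requests.post(
                "http://localhost:8000/generate",
                json={"question": question},
                timeout=60,
            )
            resp.raise_for_status()
            st.write(resp.json()["answer"])
        except requests.RequestException as err:
            st.error(f"Request to the Ray Serve API failed: {err}")
```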

8. Testing and Demonstration

I tested the deployment by running the entire setup locally with Docker. By entering medical questions into the Streamlit application, I demonstrated how the model could generate relevant answers based on the fine-tuned Flan-T5 with LoRA adaptations.

9. Evaluation

I added an evaluation script to assess the fine-tuned Flan-T5 model with LoRA on the MedQuad dataset. I loaded the saved fine-tuned model, generated answers for 200 validation samples, and calculated ROUGE scores.
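A sketch of such an evaluation loop, assuming the evaluate library for ROUGE and the generation settings shown here (the original script may differ):

```python
import evaluate

rouge = evaluate.load("rouge")
predictions, references = [], []

# Generate an answer for each of the first 200 validation samples
for example in dataset["validation"].select(range(200)):
    inputs = tokenizer(prefix + example["question"], return_tensors="pt")
    output_ids = peft_model.generate(**inputs, max_new_tokens=256)
    predictions.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    references.append(example["answer"])

# Compute and print the ROUGE scores over the generated answers
print(rouge.compute(predictions=predictions, references=references))
```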