Torch Ensemble-Based Microscopic Fungi Classification

The main purpose of this code was to build a deep learning PyTorch K-fold ensemble model for classifying images of microscopic fungi into different classes. I structured the project to include both training and prediction functionality in the app directory, which I later deployed with FastAPI (and its built-in Swagger UI) to serve the model via an API that could be tested using Postman.

Here’s a breakdown of everything I did:

Device Configuration: First, I set up the device configuration. The code uses the MPS backend (Metal Performance Shaders) on macOS if available, otherwise falling back to CUDA (for NVIDIA GPUs) or the CPU.
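
A minimal sketch of that device-selection logic (the helper name get_device is mine, not necessarily the project's):

```python
import torch

def get_device() -> torch.device:
    # Prefer Apple's Metal Performance Shaders backend on macOS,
    # then CUDA GPUs, and finally fall back to the CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
```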

Configuration Setup: I defined the config dictionary, which holds all the hyperparameters and configurations required for training and testing the model. This includes settings like the base directory for the dataset, batch size, number of epochs, image dimensions, learning rate, and more.
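
For illustration, a hypothetical config dictionary of this shape; the keys and values below are placeholders rather than the project's exact settings:

```python
config = {
    "base_dir": "data/fungi",   # root directory of the dataset (placeholder path)
    "batch_size": 32,
    "num_epochs": 30,
    "img_size": 224,            # images are resized to img_size x img_size
    "learning_rate": 1e-3,
    "num_folds": 5,             # folds for stratified K-fold cross-validation
    "num_classes": 5,           # number of fungi classes (placeholder)
    "seed": 42,
    "model_dir": "models",      # where per-fold checkpoints are written
}
```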

FungiDataset Class: I created a custom dataset class called FungiDataset that loads and preprocesses the images. This class also handles the organization of the dataset into different classes. The images are loaded from the specified directory, and transformations like resizing and normalization are applied.
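
A sketch of what FungiDataset and a directory-walking helper could look like, assuming one subdirectory per class and albumentations-style transforms (the build_file_list helper and the OpenCV-based loading are my assumptions):

```python
import os
import cv2
from torch.utils.data import Dataset

class FungiDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image and convert BGR (OpenCV default) to RGB.
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            # albumentations returns a dict; the tensor lives under "image".
            image = self.transform(image=image)["image"]
        return image, self.labels[idx]

def build_file_list(base_dir):
    # Assumes one subdirectory per class: base_dir/<class_name>/<image files>.
    class_names = sorted(os.listdir(base_dir))
    paths, labels = [], []
    for label, name in enumerate(class_names):
        for fname in os.listdir(os.path.join(base_dir, name)):
            paths.append(os.path.join(base_dir, name, fname))
            labels.append(label)
    return paths, labels, class_names
```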

CustomCNN Model: I defined a convolutional neural network (CNN) model called CustomCNN. This model consists of several convolutional layers followed by fully connected layers. The architecture is simple yet effective for image classification tasks.
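
A rough sketch of a CustomCNN in this spirit; the number of blocks and layer widths below are assumptions, not the exact architecture used:

```python
import torch.nn as nn

class CustomCNN(nn.Module):
    def __init__(self, num_classes: int, img_size: int = 224):
        super().__init__()
        # Three conv blocks, each halving the spatial resolution.
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        feat_size = img_size // 8  # three 2x poolings
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * feat_size * feat_size, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```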

Data Augmentation: I used the albumentations library for data augmentation and preprocessing. The get_transforms function returns a composition of transformations, including random cropping, flipping, rotation, and normalization, to make the model more robust.
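
A possible get_transforms along those lines, with assumed augmentation parameters and ImageNet normalization statistics:

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_transforms(img_size: int, train: bool = True):
    if train:
        return A.Compose([
            # Resize slightly larger, then take a random crop of the target size.
            A.Resize(height=img_size + 32, width=img_size + 32),
            A.RandomCrop(height=img_size, width=img_size),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=30, p=0.5),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    # Validation / inference: deterministic resize and normalization only.
    return A.Compose([
        A.Resize(height=img_size, width=img_size),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
```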

Training Process: I implemented the training process in the train_model function. Here, I used stratified K-fold cross-validation to ensure the model trains on a balanced distribution of classes across different folds. I also incorporated techniques like learning rate scheduling and early stopping to optimize training.
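
A condensed sketch of such a train_model loop, reusing the helpers sketched above; the optimizer, scheduler, and early-stopping patience values are assumptions:

```python
import os
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedKFold

def train_model(paths, labels, config, device):
    os.makedirs(config["model_dir"], exist_ok=True)
    skf = StratifiedKFold(n_splits=config["num_folds"], shuffle=True,
                          random_state=config["seed"])
    for fold, (train_idx, val_idx) in enumerate(skf.split(paths, labels)):
        train_ds = FungiDataset([paths[i] for i in train_idx],
                                [labels[i] for i in train_idx],
                                get_transforms(config["img_size"], train=True))
        val_ds = FungiDataset([paths[i] for i in val_idx],
                              [labels[i] for i in val_idx],
                              get_transforms(config["img_size"], train=False))
        train_dl = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
        val_dl = DataLoader(val_ds, batch_size=config["batch_size"])

        model = CustomCNN(config["num_classes"], config["img_size"]).to(device)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

        best_val_loss, patience_left = float("inf"), 5  # simple early stopping
        for epoch in range(config["num_epochs"]):
            model.train()
            for images, targets in train_dl:
                images, targets = images.to(device), targets.to(device)
                optimizer.zero_grad()
                loss = criterion(model(images), targets)
                loss.backward()
                optimizer.step()

            # Validation drives the scheduler, early stopping, and checkpointing.
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for images, targets in val_dl:
                    images, targets = images.to(device), targets.to(device)
                    val_loss += criterion(model(images), targets).item()
            val_loss /= len(val_dl)
            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss, patience_left = val_loss, 5
                torch.save(model.state_dict(),
                           f"{config['model_dir']}/fold_{fold}_best.pt")
            else:
                patience_left -= 1
                if patience_left == 0:
                    break
```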

Model Checkpointing: During training, I saved the model checkpoints for each fold whenever it achieved the best validation loss. This ensures that the best version of the model for each fold is saved and can be used later for prediction.

Ensemble Prediction: After training, I wrote a function called load_fold_models to load the best models from each fold. I then used these models to make ensemble predictions on new images. The predictions from each model are averaged to get the final prediction, which improves accuracy and generalization.
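
A minimal load_fold_models sketch, assuming the per-fold checkpoint naming used in the training sketch above:

```python
import torch

def load_fold_models(config, device):
    """Load the best checkpoint from every fold and put it in eval mode."""
    models = []
    for fold in range(config["num_folds"]):
        model = CustomCNN(config["num_classes"], config["img_size"]).to(device)
        state = torch.load(f"{config['model_dir']}/fold_{fold}_best.pt",
                           map_location=device)
        model.load_state_dict(state)
        model.eval()
        models.append(model)
    return models
```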

Preprocessing and Prediction: I created a preprocess_image function that loads and preprocesses a single image, preparing it for input into the trained models. I then used the predict_ensemble function to get the final class prediction for the image.
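
Sketches of preprocess_image and predict_ensemble consistent with the pieces above; averaging the softmax probabilities is one common way to combine the fold models:

```python
import cv2
import torch

def preprocess_image(image_path, config):
    # Same resize/normalize pipeline as validation, plus a batch dimension.
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    transform = get_transforms(config["img_size"], train=False)
    tensor = transform(image=image)["image"]
    return tensor.unsqueeze(0)  # shape: (1, 3, H, W)

def predict_ensemble(models, image_tensor, device, class_names):
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        # Average the softmax probabilities from all fold models.
        probs = torch.stack([torch.softmax(m(image_tensor), dim=1) for m in models])
        avg_probs = probs.mean(dim=0)
    return class_names[avg_probs.argmax(dim=1).item()]
```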

API Deployment: Finally, I deployed this model as an API using FastAPI. The FastAPI application allows users to upload an image of microscopic fungi and receive predictions on the class of fungi using an ensemble of trained models. The API is structured with a main route /predict, where users can upload an image file, and the API returns the predicted class for that image. The deployment is tested using Postman, and the FastAPI server is run using Uvicorn.
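
A minimal FastAPI sketch of the /predict route, reusing the helpers above; the temporary-file handling and the response shape are assumptions:

```python
import shutil
import tempfile

from fastapi import FastAPI, File, UploadFile

app = FastAPI(title="Microscopic Fungi Classifier")

# Load the fold models once at startup and reuse them for every request.
device = get_device()
models = load_fold_models(config, device)
_, _, class_names = build_file_list(config["base_dir"])  # class names from the dataset layout

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # Persist the upload to a temporary file so OpenCV can read it from disk.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        shutil.copyfileobj(file.file, tmp)
        tmp_path = tmp.name
    image_tensor = preprocess_image(tmp_path, config)
    predicted_class = predict_ensemble(models, image_tensor, device, class_names)
    return {"predicted_class": predicted_class}
```

With an entry point like this saved as app/main.py (an assumed path), the server can be started with `uvicorn app.main:app` and exercised from Postman or the auto-generated Swagger UI at /docs.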

Tests: The test cases in test_prediction.py verify that the ensemble model predictions work correctly.
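
Illustrative pytest-style checks in the spirit of test_prediction.py, building on the sketches above; the sample image path is a hypothetical fixture, not a file from the repository:

```python
# test_prediction.py -- illustrative checks, not the project's actual tests
def test_ensemble_prediction_returns_known_class():
    device = get_device()
    models = load_fold_models(config, device)
    _, _, class_names = build_file_list(config["base_dir"])
    image_tensor = preprocess_image("tests/sample_fungi.png", config)  # hypothetical fixture
    prediction = predict_ensemble(models, image_tensor, device, class_names)
    assert prediction in class_names

def test_preprocess_image_shape():
    tensor = preprocess_image("tests/sample_fungi.png", config)  # hypothetical fixture
    assert tensor.shape == (1, 3, config["img_size"], config["img_size"])
```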