This repository implements a multi-faceted framework for fine-grained classification of Reddit posts into six stages of the opioid use continuum:
- Medical Use
- Misuse
- Addiction
- Recovery
- Relapse
- Not Using
The project supports multiple modeling strategies:
- Baseline supervised classifiers (DeBERTa, T5)
- Reasoning Distillation (T5): summarized + step-by-step rationales
- Supervised Contrastive Learning (SCL) pretraining + fine-tuning
- In-Context Learning (ICL) relabeling with GPT‑5 + downstream training
- Zero-shot GPT‑5 evaluation
- Python 3.9+

```bash
git clone https://github.com/vinuekanayake/Opioid-Stage
cd OUD_Stage_Classification
```

The repository uses a central `ConfigLoader` to manage YAML configs.
Defines where your data and outputs live:

```yaml
data:
  worker_train_wo_explanation: "data/worker_data/train_wo_explanation.csv"
  worker_eval_wo_explanation: "data/worker_data/eval_wo_explanation.csv"
  expert_eval_wo_explanation: "data/expert_data/eval_wo_explanation.csv"
```

Each experiment (baseline, reasoning, scl, icl_relabel, zeroshot, etc.) has a YAML file that describes:
- Which model to use (DeBERTa, T5, size)
- Which data split to use (wo vs w explanations)
- Learning rate, batch size, epochs, etc.
- Output directories for checkpoints, logs, and predictions
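The repository's `ConfigLoader` is referenced but not shown above; a minimal sketch of what such a dotted-key loader might look like with PyYAML (the class interface and method names here are assumptions, not the repository's actual code):

```python
import yaml  # PyYAML


class ConfigLoader:
    """Load a YAML config and expose nested keys via dotted paths.

    Hypothetical sketch -- the repository's real ConfigLoader may differ.
    """

    def __init__(self, yaml_text):
        self.cfg = yaml.safe_load(yaml_text)

    @classmethod
    def from_file(cls, path):
        with open(path) as f:
            return cls(f.read())

    def get(self, dotted_key, default=None):
        # "data.worker_train_wo_explanation" -> cfg["data"]["worker_train_wo_explanation"]
        node = self.cfg
        for part in dotted_key.split("."):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node


# Example with the data config shown above:
loader = ConfigLoader("""
data:
  worker_train_wo_explanation: "data/worker_data/train_wo_explanation.csv"
""")
print(loader.get("data.worker_train_wo_explanation"))
# -> data/worker_data/train_wo_explanation.csv
```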
Baseline training uses standard supervised learning on the original worker labels.
```bash
python train_baseline.py \
  --model deberta_base \
  --data_type wo
```
Where:
- `--model`: one of `deberta_base`, `deberta_large`, `t5_3b`, `t5_11b`
- `--data_type`:
  - `wo`: without explanations (uses `*_wo_explanation.csv`)
  - `w`: with explanations (uses `*_w_explanation.csv`)
This method distills rationales (summarized or step-by-step) generated by DeepSeek R1 into T5 student models using a conditional language modeling objective (label + rationale).
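To make the conditional language-modeling objective concrete, here is a hypothetical sketch of how (input, target) text pairs might be built for the T5 student. The prompt templates, function name, and field names are assumptions for illustration; the repository's actual preprocessing may differ:

```python
# Hypothetical sketch of label + rationale target construction for the T5
# student; the real templates and column names in the repo may differ.

def build_distillation_example(post, label, rationale, mode):
    """Return (input_text, target_text) for conditional LM training.

    mode: "summarized" or "step-by-step", matching the --reasoning flag.
    """
    if mode == "summarized":
        input_text = f"classify and explain: {post}"
        target_text = f"label: {label} rationale: {rationale}"
    else:  # step-by-step
        input_text = f"classify step by step: {post}"
        target_text = f"reasoning: {rationale} therefore label: {label}"
    return input_text, target_text


inp, tgt = build_distillation_example(
    post="Finally two weeks clean after my taper.",
    label="Recovery",
    rationale="The author reports sustained abstinence after tapering off.",
    mode="summarized",
)
# tgt contains both the gold label and the distilled rationale
```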
```bash
# Summarized Reasoning
python train_reasoning.py \
  --model t5_3b \
  --reasoning summarized \
  --data_type wo

# Step-by-Step Reasoning
python train_reasoning.py \
  --model t5_11b \
  --reasoning step-by-step \
  --data_type wo
```

SCL pretrains the encoder with a supervised contrastive loss, then fine-tunes for classification.
```bash
# DeBERTa-v3-base, wo_explanation
python scl_pretrain.py \
  --model deberta_base \
  --data_type wo
```

Saves `encoder.pth` and `projection.pth` to `checkpoints/scl_<model>_<data_type>/`.
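For reference, the supervised contrastive objective pulls same-class embeddings together and pushes different-class embeddings apart. A NumPy sketch of the SupCon loss (Khosla et al., 2020) that `scl_pretrain.py` presumably optimizes — the repository's actual implementation is most likely in PyTorch and may differ in detail:

```python
import numpy as np


def supcon_loss(features, labels, temperature=0.07):
    """Supervised contrastive (SupCon) loss over one batch.

    features: (N, D) L2-normalized embeddings; labels: (N,) integer class ids.
    Illustrative sketch only -- not the repository's actual code.
    """
    n = features.shape[0]
    sim = features @ features.T / temperature         # pairwise similarities
    np.fill_diagonal(sim, -np.inf)                    # exclude self-contrast
    # log-softmax over each anchor's row
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    # positives: same label, excluding the anchor itself
    pos_mask = (labels[:, None] == labels[None, :]) & ~np.eye(n, dtype=bool)
    per_anchor = np.where(pos_mask, log_prob, 0.0).sum(axis=1)
    counts = pos_mask.sum(axis=1)
    valid = counts > 0                                # skip anchors w/o positives
    return -(per_anchor[valid] / counts[valid]).mean()
```

Tightly clustered same-class embeddings yield a loss near zero, while randomly scattered embeddings yield a much larger loss — the gradient signal the pretraining stage exploits.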
```bash
# DeBERTa + SCL-encoder, wo_explanation
python scl_finetune.py \
  --model deberta_base \
  --data_type wo
```

Optional: load a specific epoch checkpoint:

```bash
# DeBERTa + SCL-encoder, wo_explanation
python scl_finetune.py \
  --model deberta_base \
  --data_type wo \
  --checkpoint_epoch 25
```

ICL uses GPT‑5 with a small set of expert examples to relabel the worker training data with more expert-like labels.
```bash
python icl_relabel.py \
  --api_key $OPENAI_API_KEY
```

To process a single ICL set:

```bash
python icl_relabel.py \
  --api_key $OPENAI_API_KEY \
  --icl_set 1
```

To combine the relabeled sets by majority vote:

```bash
python icl_majority_vote.py \
  --min_agreement 6
```

Saves:
- `combined_labels_detailed.csv` (original + all labels + mode + count)
- `train_icl_relabeled.csv` (final relabeled dataset with column `label`)

`--min_agreement` is optional; e.g., `6` requires at least 6 sets to agree.
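The per-post voting rule can be sketched as follows — a hypothetical illustration of what `icl_majority_vote.py` computes (the real script operates on CSVs, and its function and column names may differ):

```python
# Hypothetical sketch of the majority-vote combination step.
from collections import Counter


def majority_vote(labels_per_set, min_agreement=6):
    """Combine labels from several ICL relabeling runs for one post.

    Returns (mode_label, count) if at least `min_agreement` sets agree,
    otherwise (None, count) so the post can be filtered out downstream.
    """
    (mode_label, count), = Counter(labels_per_set).most_common(1)
    if count >= min_agreement:
        return mode_label, count
    return None, count


label, votes = majority_vote(
    ["Misuse"] * 6 + ["Addiction", "Misuse"], min_agreement=6)
# -> label == "Misuse", votes == 7
```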
```bash
python train_icl.py \
  --model deberta_base \
  --data_type wo
```

Zero-shot evaluation: GPT‑5 predicts labels directly on test posts without any fine-tuning.
```bash
python zeroshot_gpt5.py \
  --dataset worker \
  --data_type wo \
  --api_key $OPENAI_API_KEY \
  --save_errors
```

Where:
- `--dataset`: `worker` or `expert`
- `--save_errors` (optional): saves all misclassified examples to an errors CSV
- `--resume`: resume from existing predictions (skips already classified rows)
- `--overwrite`: overwrite existing CSV and start from scratch
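A zero-shot run boils down to building a prompt per post and mapping the model's free-text reply back to one of the six stages. The prompt wording and helper names below are assumptions, not the actual contents of `zeroshot_gpt5.py`; the API call itself is omitted so the sketch stays runnable offline:

```python
# Hypothetical sketch of zero-shot prompt construction and response parsing.
STAGES = ["Medical Use", "Misuse", "Addiction", "Recovery", "Relapse", "Not Using"]


def build_prompt(post):
    """Ask the model for exactly one stage name, nothing else."""
    return (
        "Classify the Reddit post into exactly one opioid use stage: "
        + ", ".join(STAGES) + ".\n"
        "Answer with only the stage name.\n\n"
        f"Post: {post}"
    )


def parse_label(response_text):
    """Map a free-text model response to one of the six stages, else None."""
    text = response_text.strip().lower()
    for stage in STAGES:
        if stage.lower() in text:
            return stage
    return None
```

Parsing defensively (rather than trusting the raw reply) matters because even with an "answer with only the stage name" instruction, the model may add punctuation or framing text.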