Opioid Use Stage Classification Framework

This repository implements a multi-faceted framework for fine-grained classification of Reddit posts into six stages of the opioid use continuum:

  1. Medical Use
  2. Misuse
  3. Addiction
  4. Recovery
  5. Relapse
  6. Not Using

The project supports multiple modeling strategies:

  • Baseline supervised classifiers (DeBERTa, T5)
  • Reasoning Distillation (T5) with summarized and step-by-step reasoning
  • Supervised Contrastive Learning (SCL) pretraining + fine-tuning
  • In-Context Learning (ICL) relabeling with GPT‑5 + downstream training
  • Zero-shot GPT‑5 evaluation

Prerequisites

  • Python 3.9+

1. Clone Repository

git clone https://github.com/vinuekanayake/Opioid-Stage
cd Opioid-Stage

2. Configuration Files

The repository uses a central ConfigLoader to manage YAML configs.

config/paths.yaml

Defines where your data and outputs live:

data:
  worker_train_wo_explanation: "data/worker_data/train_wo_explanation.csv"
  worker_eval_wo_explanation: "data/worker_data/eval_wo_explanation.csv"
  expert_eval_wo_explanation: "data/expert_data/eval_wo_explanation.csv"

config/training_configs/*.yaml

Each experiment (baseline, reasoning, scl, icl_relabel, zeroshot, etc.) has a YAML file that describes:

  • Which model to use (DeBERTa, T5, size)
  • Which data split to use (wo vs w explanations)
  • Learning rate, batch size, epochs, etc.
  • Output directories for checkpoints, logs, and predictions
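As a hedged illustration, one such training config might look like the following; all field names and values here are assumptions for illustration, since the repository defines its own schema:

```yaml
# config/training_configs/baseline.yaml (illustrative field names)
model:
  name: deberta_base          # deberta_base | deberta_large | t5_3b | t5_11b
data:
  type: wo                    # wo (without explanations) | w (with explanations)
training:
  learning_rate: 2.0e-5
  batch_size: 16
  epochs: 10
output:
  checkpoints_dir: checkpoints/baseline_deberta_base_wo
  logs_dir: logs/baseline_deberta_base_wo
  predictions_dir: predictions/baseline_deberta_base_wo
```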

3. Usage

3.1 Baseline Supervised Training

Baseline training uses standard supervised learning on the original worker labels.

python train_baseline.py \
  --model deberta_base \
  --data_type wo

Where:

  • --model: one of deberta_base, deberta_large, t5_3b, t5_11b
  • --data_type:
    • wo: without explanations (use *_wo_explanation.csv)
    • w: with explanations (use *_w_explanation.csv)

3.2 Reasoning Distillation

This method distills rationales (summarized or step-by-step) generated by DeepSeek R1 into T5 student models using a conditional language modeling objective (label + rationale).

# Summarized Reasoning
python train_reasoning.py \
  --model t5_3b \
  --reasoning summarized \
  --data_type wo

# Step-by-Step Reasoning
python train_reasoning.py \
  --model t5_11b \
  --reasoning step-by-step \
  --data_type wo
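For intuition, the distillation target pairs the gold label with the teacher rationale in a single output string that the T5 student learns to generate. The template below is a minimal sketch; the actual serialization format and separator used in the repository are assumptions:

```python
# Hypothetical sketch of the conditional LM target used for distillation:
# the student is trained to emit the label followed by the teacher's
# (DeepSeek R1) rationale. The template here is illustrative only.
def build_target(label: str, rationale: str) -> str:
    """Serialize label + rationale into one generation target."""
    return f"label: {label} | rationale: {rationale}"

print(build_target("Recovery", "The poster reports sustained abstinence."))
# -> label: Recovery | rationale: The poster reports sustained abstinence.
```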

3.3 Supervised Contrastive Learning (SCL)

SCL pretrains the encoder with a supervised contrastive loss, then fine-tunes for classification.
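The supervised contrastive objective pulls embeddings with the same stage label together and pushes different-label embeddings apart. A minimal NumPy sketch of the SupCon loss follows; the temperature value and batching details are assumptions, not the repository's exact implementation:

```python
import numpy as np

def supcon_loss(z, labels, tau=0.1):
    """Supervised contrastive loss over a batch of embeddings z (N, d)."""
    labels = np.asarray(labels)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # L2-normalize
    sim = z @ z.T / tau                                # pairwise similarities
    np.fill_diagonal(sim, -np.inf)                     # exclude self-pairs
    # Row-wise log-softmax over all other samples in the batch.
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    same = labels[:, None] == labels[None, :]
    np.fill_diagonal(same, False)
    # Average log-probability over each anchor's positives, then negate.
    pos_counts = np.maximum(same.sum(axis=1), 1)
    per_anchor = np.where(same, log_prob, 0.0).sum(axis=1) / pos_counts
    return float(-per_anchor[same.any(axis=1)].mean())
```

The loss is small when same-label embeddings cluster tightly and grows when positives are far apart, which is what makes the pretrained encoder useful for the downstream classifier.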

Pretraining

# DeBERTa-v3-base, wo_explanation
python scl_pretrain.py \
  --model deberta_base \
  --data_type wo

Pretraining saves encoder.pth and projection.pth to checkpoints/scl_<model>_<data_type>/.

Finetuning

# DeBERTa + SCL-encoder, wo_explanation
python scl_finetune.py \
  --model deberta_base \
  --data_type wo

Optional: load a specific epoch checkpoint:

# DeBERTa + SCL-encoder, wo_explanation
python scl_finetune.py \
  --model deberta_base \
  --data_type wo \
  --checkpoint_epoch 25

3.4 In-Context Learning (ICL)

ICL uses GPT‑5 with a small set of expert examples to relabel the worker training data with more expert-like labels.

GPT‑5 ICL Relabeling

python icl_relabel.py \
  --api_key $OPENAI_API_KEY

To process a single ICL set:

python icl_relabel.py \
  --api_key $OPENAI_API_KEY \
  --icl_set 1

Majority Voting Across ICL Sets

python icl_majority_vote.py \
  --min_agreement 6

Outputs:

  • combined_labels_detailed.csv (original post + labels from all ICL sets + mode + agreement count)
  • train_icl_relabeled.csv (final relabeled dataset with a label column)

--min_agreement is optional; e.g., 6 requires at least 6 ICL sets to agree on a label.
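The voting rule described above can be sketched in a few lines: keep the modal label for a post only when it reaches the agreement threshold. This is an illustrative reconstruction, not the script's exact code:

```python
from collections import Counter

def majority_vote(labels, min_agreement=6):
    """Return the modal label if enough ICL sets agree, else None."""
    label, count = Counter(labels).most_common(1)[0]
    return label if count >= min_agreement else None

# Example: 7 of 10 ICL sets agree on "Recovery".
votes = ["Recovery"] * 7 + ["Relapse"] * 3
print(majority_vote(votes, min_agreement=6))  # -> Recovery
```

Posts whose labels fall below the threshold would be dropped (or kept with their original worker label, depending on the script's policy).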

Training on ICL-Relabeled Data

python train_icl.py \
  --model deberta_base \
  --data_type wo

3.5 Zero-Shot GPT‑5 Evaluation

Zero-shot evaluation: GPT‑5 predicts labels directly on test posts without any fine-tuning.

GPT‑5 Zero-Shot Evaluation

python zeroshot_gpt5.py \
  --dataset worker \
  --data_type wo \
  --api_key $OPENAI_API_KEY \
  --save_errors

Where:

  • --dataset: one of worker or expert
  • --save_errors: optionally saves all misclassified examples to an errors CSV
  • --resume: resumes from existing predictions (skips already classified rows)
  • --overwrite: overwrites the existing CSV and starts from scratch
