Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions atroposlib/tests/test_env_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Regression tests for environment module imports.

Ensures every environment module can be imported without errors
(e.g. no stale references to renamed symbols like OpenaiConfig).
"""

import importlib

import pytest


@pytest.mark.parametrize(
"module_path",
[
"environments.sft_loader_server",
"environments.community.ufc_prediction_env.ufc_server",
"environments.community.ufc_prediction_env.ufc_image_env",
],
)
def test_environment_module_imports(module_path):
"""Each environment module should import without ImportError."""
importlib.import_module(module_path)
13 changes: 9 additions & 4 deletions environments/community/ufc_prediction_env/ufc_image_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from PIL import Image
from pydantic import Field

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

Expand Down Expand Up @@ -44,7 +49,7 @@ class UFCImageEnv(BaseEnv):
def __init__(
self,
config: UFCImageEnvConfig,
server_configs: List[OpenaiConfig],
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
Expand Down Expand Up @@ -323,7 +328,7 @@ async def evaluate(self, *args, **kwargs):
return

@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
"""Initialize configuration for the environment"""
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
Expand All @@ -343,7 +348,7 @@ def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
)

server_configs = [
OpenaiConfig(
APIServerConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
Expand Down
9 changes: 7 additions & 2 deletions environments/community/ufc_prediction_env/ufc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

from pydantic import Field

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

Expand Down Expand Up @@ -37,7 +42,7 @@ class UFCEnv(BaseEnv):
def __init__(
self,
config: UFCEnvConfig,
server_configs: List[OpenaiConfig],
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
Expand Down
15 changes: 10 additions & 5 deletions environments/sft_loader_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from datasets import load_dataset
from pydantic import Field

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item


Expand Down Expand Up @@ -58,7 +63,7 @@ class SFTEnv(BaseEnv):
def __init__(
self,
config: SFTConfig,
server_configs: List[OpenaiConfig],
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
):
Expand All @@ -72,7 +77,7 @@ def __init__(
self.last_step = -1

@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_config = SFTConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=8,
Expand All @@ -91,7 +96,7 @@ def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
max_sft_per_step=8,
)
server_configs = [
OpenaiConfig(
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9001/v1",
api_key="x",
Expand Down Expand Up @@ -235,7 +240,7 @@ async def checkout_formatting():
dataset_column_name="conversations",
),
server_configs=[
OpenaiConfig(
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9001/v1",
api_key="x",
Expand Down