diff --git a/atroposlib/tests/test_env_imports.py b/atroposlib/tests/test_env_imports.py new file mode 100644 index 000000000..103da5113 --- /dev/null +++ b/atroposlib/tests/test_env_imports.py @@ -0,0 +1,22 @@ +"""Regression tests for environment module imports. + +Ensures every environment module can be imported without errors +(e.g. no stale references to renamed symbols like OpenaiConfig). +""" + +import importlib + +import pytest + + +@pytest.mark.parametrize( + "module_path", + [ + "environments.sft_loader_server", + "environments.community.ufc_prediction_env.ufc_server", + "environments.community.ufc_prediction_env.ufc_image_env", + ], +) +def test_environment_module_imports(module_path): + """Each environment module should import without ImportError.""" + importlib.import_module(module_path) diff --git a/environments/community/ufc_prediction_env/ufc_image_env.py b/environments/community/ufc_prediction_env/ufc_image_env.py index a63f21db2..f824e8ceb 100644 --- a/environments/community/ufc_prediction_env/ufc_image_env.py +++ b/environments/community/ufc_prediction_env/ufc_image_env.py @@ -10,7 +10,12 @@ from PIL import Image from pydantic import Field -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -44,7 +49,7 @@ class UFCImageEnv(BaseEnv): def __init__( self, config: UFCImageEnvConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -323,7 +328,7 @@ async def evaluate(self, *args, **kwargs): return @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: """Initialize configuration for the environment""" if not os.environ.get("OPENAI_API_KEY"): print("ERROR: OPENAI_API_KEY environment variable is not set!") @@ -343,7 +348,7 @@ def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="gpt-4o", base_url=None, api_key=os.environ.get("OPENAI_API_KEY"), diff --git a/environments/community/ufc_prediction_env/ufc_server.py b/environments/community/ufc_prediction_env/ufc_server.py index f127449dc..935cad320 100644 --- a/environments/community/ufc_prediction_env/ufc_server.py +++ b/environments/community/ufc_prediction_env/ufc_server.py @@ -7,7 +7,12 @@ from pydantic import Field -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import GameHistory, Item from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer @@ -37,7 +42,7 @@ class UFCEnv(BaseEnv): def __init__( self, config: UFCEnvConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): diff --git a/environments/sft_loader_server.py b/environments/sft_loader_server.py index af0811d02..46277dd5f 100644 --- a/environments/sft_loader_server.py +++ b/environments/sft_loader_server.py @@ -5,7 +5,12 @@ from datasets import load_dataset from pydantic import Field -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + ScoredDataGroup, +) from atroposlib.type_definitions import Item @@ -58,7 +63,7 @@ class SFTEnv(BaseEnv): def __init__( self, config: SFTConfig, - server_configs: List[OpenaiConfig], + server_configs: List[APIServerConfig], slurm=True, testing=False, ): @@ -72,7 +77,7 @@ def __init__( self.last_step = -1 @classmethod - def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: + def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]: env_config = SFTConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", group_size=8, @@ -91,7 +96,7 @@ def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]: max_sft_per_step=8, ) server_configs = [ - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9001/v1", api_key="x", @@ -235,7 +240,7 @@ async def checkout_formatting(): dataset_column_name="conversations", ), server_configs=[ - OpenaiConfig( + APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9001/v1", api_key="x",