def _open_text(path, mode='r'):
    """Open a text file as utf-8-sig, falling back to cp1252.

    The file is first opened as utf-8-sig (UTF-8 with an optional BOM) and
    read through once to prove that it decodes. If that succeeds, the handle
    is rewound and returned. If decoding fails, the probe handle is closed
    and the file is reopened as cp1252.

    Args:
        path: Path of the file to open.
        mode: Text mode passed to ``open``. It must allow reading, because
            the encoding probe reads the whole file.

    Returns:
        An open text file object positioned at the start of the file. The
        caller owns it and must close it (normally via ``with``).

    Raises:
        OSError: If the file cannot be opened or read.

    Note:
        cp1252 leaves a few byte values undefined, so a later ``read()`` on
        the fallback handle can still raise ``UnicodeDecodeError`` for files
        that are neither valid UTF-8 nor valid cp1252.
    """
    probe = open(path, mode, encoding='utf-8-sig')
    try:
        # Decode the whole file once; this is the only reliable way to know
        # that every byte is valid UTF-8 before handing the handle out.
        probe.read()
        probe.seek(0)
    except UnicodeDecodeError:
        # Close the failed probe before reopening, otherwise its file
        # descriptor leaks for every legacy-encoded file.
        probe.close()
        return open(path, mode, encoding='cp1252')
    except BaseException:
        # Any other failure (I/O error, unreadable mode, interrupt) must not
        # leave the probe handle open either.
        probe.close()
        raise
    return probe
os.path.exists(caption_path): raise Exception(f"Error: caption file not found for poi: {caption_path}") - with open(caption_path, 'r', encoding='utf-8') as f: + with _open_text(caption_path) as f: json_data = json.load(f) if 'poi' not in json_data: print_acc(f"Warning: poi not found in caption file: {caption_path}")