From ed7d29e330df29ad5275d777cc02b6178b43c58f Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Thu, 8 Feb 2024 12:40:27 +0100 Subject: [PATCH 01/23] init poetry packaging --- .gitignore | 166 +++++++ Makefile | 14 + ml_mgie/demo/__init__.py | 0 ml_mgie/demo/inference.py | 92 ++++ ml_mgie/ml_mgie/__init__.py | 0 ml_mgie/ml_mgie/base.py | 12 + ml_mgie/ml_mgie/llava_conversation.py | 367 +++++++++++++++ ml_mgie/ml_mgie/mgie.py | 129 ++++++ ml_mgie/ml_mgie/mgie_llava.py | 625 ++++++++++++++++++++++++++ ml_mgie/ml_mgie/utils.py | 26 ++ pyproject.toml | 33 ++ 11 files changed, 1464 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 ml_mgie/demo/__init__.py create mode 100644 ml_mgie/demo/inference.py create mode 100644 ml_mgie/ml_mgie/__init__.py create mode 100644 ml_mgie/ml_mgie/base.py create mode 100644 ml_mgie/ml_mgie/llava_conversation.py create mode 100644 ml_mgie/ml_mgie/mgie.py create mode 100644 ml_mgie/ml_mgie/mgie_llava.py create mode 100644 ml_mgie/ml_mgie/utils.py create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..58f9a56 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Custom +venv* +.DS_Store +*.pt +*.tar.gz \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..60f1d8e --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ +PROJECT_NAME := ml-mgie + +.PHONY: venv +venv: + python -m venv venv_${PROJECT_NAME} && echo "run: source venv_${PROJECT_NAME}/bin/activate" + +.PHONY: check_type +check_type: ## run mypy + python -m mypy --config-file=./mypy.ini ./${PROJECT_NAME}/${PROJECT_NAME} + +.PHONY: tests +tests: check_type + python -u -m pytest ./${PROJECT_NAME}/tests -vv -s + diff --git a/ml_mgie/demo/__init__.py b/ml_mgie/demo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ml_mgie/demo/inference.py b/ml_mgie/demo/inference.py new file mode 100644 index 0000000..5ea4860 --- /dev/null +++ b/ml_mgie/demo/inference.py @@ -0,0 +1,92 @@ +import os +import torch +import tqdm +from PIL import Image + +from ml_mgie.mgie import ( + MGIE, + MGIEParams, + DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN, + DEFAULT_IM_END_TOKEN, +) + +from ml_mgie.utils import remove_alter, crop_resize +from ml_mgie.llava_conversation import conv_templates + +SEED = 13331 +mgie = MGIE(MGIEParams(device="cuda")) + +ins = [ + "make the frame red", + "turn the day into night", + "give him a beard", + "make cottage a mansion", + "remove yellow object from dogs paws", + "change the hair from red to blue", + "remove the text", + "increase the image contrast", + "remove the people in the background", + "please make this photo professional looking", + "darken the image, sharpen it", + "photoshop the girl out", + "make more brightness", + "take away the brown filter form the image", + "add more contrast to simulate more light", + "dark on rgb", + "make the face happy", + "change view as ocean", + "replace basketball with soccer ball", + "let the floor be made of wood", +] +for i in tqdm(range(len(ins))): + img, txt = Image.open("_input/%d.jpg" % (i)).convert("RGB"), ins[i] + + img = mgie.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][0] + txt = "what will this image be like if '%s'" % (txt) + txt = ( + txt + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * mgie.image_token_len + + DEFAULT_IM_END_TOKEN + ) + conv = conv_templates["vicuna_v1_1"].copy() + conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None) + txt = conv.get_prompt() + txt = mgie.tokenizer(txt) + txt, mask = torch.as_tensor(txt["input_ids"]), T.as_tensor(txt["attention_mask"]) + + with torch.inference_mode(): + out = mgie.model.generate( + txt.unsqueeze(dim=0).cuda(), + 
images=img.half().unsqueeze(dim=0).cuda(), + attention_mask=mask.unsqueeze(dim=0).cuda(), + do_sample=False, + max_new_tokens=96, + num_beams=1, + no_repeat_ngram_size=3, + return_dict_in_generate=True, + output_hidden_states=True, + ) + out, hid = ( + out["sequences"][0].tolist(), + torch.cat([x[-1] for x in out["hidden_states"]], dim=1)[0], + ) + + p = min(out.index(32003) - 1 if 32003 in out else len(hid) - 9, len(hid) - 9) + hid = hid[p : p + 8] + + out = remove_alter(mgie.tokenizer.decode(out)) + emb = mgie.model.edit_head(hid.unsqueeze(dim=0), mgie.emb) + res = mgie.pipe( + image=Image.open("_input/%d.jpg" % (i)).convert("RGB"), + prompt_embeds=emb, + negative_prompt_embeds=mgie.null, + generator=torch.Generator(device="cuda").manual_seed(SEED), + ).images[0] + + input = Image.open("_input/%d.jpg" % (i)).convert("RGB") + os.makedirs("_output", exist_ok=True) + input.save("_output/in-%d.jpg" % (i)) + Image.fromarray(res).save("_output/out-%d.jpg" % (i)) diff --git a/ml_mgie/ml_mgie/__init__.py b/ml_mgie/ml_mgie/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ml_mgie/ml_mgie/base.py b/ml_mgie/ml_mgie/base.py new file mode 100644 index 0000000..64c3c13 --- /dev/null +++ b/ml_mgie/ml_mgie/base.py @@ -0,0 +1,12 @@ +import torch + + +def get_default_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + + +DEFAULT_DEVICE = get_default_device() diff --git a/ml_mgie/ml_mgie/llava_conversation.py b/ml_mgie/ml_mgie/llava_conversation.py new file mode 100644 index 0000000..05198fc --- /dev/null +++ b/ml_mgie/ml_mgie/llava_conversation.py @@ -0,0 +1,367 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + if self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role, message): + self.messages.append([role, message]) + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + from PIL import Image + msg, image, image_process_mode = msg + if image_process_mode == 
"Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image) + elif image_process_mode == "Crop": + pass + elif image_process_mode == "Resize": + image = image.resize((224, 224)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + images.append(image) + else: + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + msg, image, image_process_mode = msg + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + # image = image.resize((224, 224)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = msg.replace('', img_str) + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Give three tips for staying healthy."), + ("Assistant", + "Sure, here are three tips for staying healthy:\n" + "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " + "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " + "and certain cancers. 
Aim for at least 150 minutes of moderate-intensity aerobic exercise or " + "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening " + "activities at least two days per week.\n" + "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, " + "vegetables, whole grains, lean proteins, and healthy fats can help support " + "your overall health. Try to limit your intake of processed and high-sugar foods, " + "and aim to drink plenty of water throughout the day.\n" + "3. Get enough sleep: Getting enough quality sleep is essential for your physical " + "and mental health. Adults should aim for seven to nine hours of sleep per night. " + "Establish a regular sleep schedule and try to create a relaxing bedtime routine to " + "help improve the quality of your sleep.") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_v1_2 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ("Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1_1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_mpt = Conversation( + system="""<|im_start|>system +- You are a helpful language and vision assistant. +- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. 
+- You should follow the instructions carefully and explain your answers in detail.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_mpt_text = Conversation( + system="""<|im_start|>system +- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_bair_v1 = Conversation( + system="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +simple_conv = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!"), + ("Assistant", "Hi there! How can I help you today?") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_multimodal = Conversation( + system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!"), + ("Assistant", "Hi there! How can I help you today?\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +simple_conv_mpt_multimodal = Conversation( + system="""<|im_start|>system +- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. +- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. +- You should follow the instructions carefully and explain your answers in detail.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +simple_conv_legacy = Conversation( + system="You are LLaVA, a large language model trained by UW Madison WAIV Lab." + "You are designed to assist human with a variety of tasks using natural language." + "Follow the instructions carefully.", + roles=("Human", "Assistant"), + messages=( + ("Human", "Hi!\n\n### Response:"), + ("Assistant", "Hi there! How can I help you today?\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v1 = Conversation( + system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
+ "Follow the instructions carefully and explain your answers in detail.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +default_conversation = conv_v1_2 +conv_templates = { + "default": conv_v1_2, + "simple": simple_conv, + "simple_legacy": simple_conv_legacy, + "multimodal": simple_conv_multimodal, + "mpt_multimodal": simple_conv_mpt_multimodal, + "llava_v1": conv_llava_v1, + + # fastchat + "v1": conv_v1_2, + "bair_v1": conv_bair_v1, + "vicuna_v1_1": conv_vicuna_v1_1, + "mpt": conv_mpt, + "mpt_text": conv_mpt_text, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py new file mode 100644 index 0000000..eb1e67d --- /dev/null +++ b/ml_mgie/ml_mgie/mgie.py @@ -0,0 +1,129 @@ +import os +import torch +from typing import Tuple +from dataclasses import dataclass +from tqdm.auto import tqdm + +from PIL import Image + +import transformers +import diffusers + +# from llava.conversation import conv_templates +# from llava.model import * + +from .mgie_llava import LlavaLlamaForCausalLM +from .base import DEFAULT_DEVICE + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" +PATH_LLAVA = "./_ckpt/LLaVA-7B-v1" + + +@dataclass +class MGIEParams: + device: torch.device = DEFAULT_DEVICE + + +class MGIE: + def __init__(self, params: MGIEParams) -> None: + self.params = params + self.tokenizer: transformers.AutoTokenizer = None + self.model: LlavaLlamaForCausalLM = None + self.image_processor: transformers.CLIPImageProcessor = None + self.image_token_len: int = None + self.emb: torch.Tensor = None + self._get_model() + self.pipe = self._get_pipe() + + def _get_model( + self, + ) -> Tuple[ + LlavaLlamaForCausalLM, + transformers.AutoTokenizer, + transformers.CLIPImageProcessor, + int, + ]: + tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA) + model = LlavaLlamaForCausalLM.from_pretrained( + PATH_LLAVA, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + use_cache=True, + ).to(self.params.device) + image_processor = transformers.CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) + + tokenizer.padding_side = "left" + tokenizer.add_tokens( + [ + "[IMG0]", + "[IMG1]", + "[IMG2]", + "[IMG3]", + "[IMG4]", + "[IMG5]", + "[IMG6]", + "[IMG7]", + ], + special_tokens=True, + ) + model.resize_token_embeddings(len(tokenizer)) + ckpt = torch.load("./_ckpt/mgie_7b/mllm.pt", map_location="cpu") # TO DEVICE? 
+ model.load_state_dict(ckpt, strict=False) + + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) + + vision_tower = model.get_model().vision_tower[0] + vision_tower = transformers.CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(self.params.device) + model.get_model().vision_tower[0] = vision_tower + vision_config = vision_tower.config + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] + vision_config.use_im_start_end = mm_use_im_start_end + if mm_use_im_start_end: + vision_config.im_start_token, vision_config.im_end_token = ( + tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) + ) + image_token_len: int = ( + vision_config.image_size // vision_config.patch_size + ) ** 2 + + _ = model.eval() + EMB = ckpt["emb"].to(self.params.device) + with torch.inference_mode(): + NULL = model.edit_head( + torch.zeros(1, 8, 4096).half().to(self.params.device), EMB + ) + print("NULL:", NULL.shape) + self.model = model + self.tokenizer = tokenizer + self.image_processor = image_processor + self.image_token_len = image_token_len + self.emb = EMB + self.null = NULL + + def _get_pipe(self) -> diffusers.StableDiffusionInstructPix2PixPipeline: + pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained( + "timbrooks/instruct-pix2pix", torch_dtype=T.float16, safety_checker=None + ).to("cuda") + pipe.set_progress_bar_config(disable=True) + pipe.unet.load_state_dict( + torch.load("./_ckpt/mgie_7b/unet.pt", map_location="cpu") + ) # TO DEVICE? + return pipe diff --git a/ml_mgie/ml_mgie/mgie_llava.py b/ml_mgie/ml_mgie/mgie_llava.py new file mode 100644 index 0000000..3b615dd --- /dev/null +++ b/ml_mgie/ml_mgie/mgie_llava.py @@ -0,0 +1,625 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. 
+# +# modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + LlamaConfig, + LlamaModel, + LlamaForCausalLM, + CLIPVisionModel, + CLIPImageProcessor, +) + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) + +import os, diffusers + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" +REGISTER_NAME = "mgie-llava" + + +class LlavaConfig(LlamaConfig): + model_type = REGISTER_NAME + + +class LlavaLlamaModel(LlamaModel): + config_class = LlavaConfig + + def __init__(self, config: LlamaConfig): + super(LlavaLlamaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + # HACK: for FSDP + self.vision_tower = [ + CLIPVisionModel.from_pretrained(config.mm_vision_tower) + ] + # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) + + if hasattr(config, "use_mm_proj"): + self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) + + def get_vision_tower(self): + vision_tower = getattr(self, "vision_tower", None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_vision_modules( + self, + vision_tower, + mm_vision_select_layer, + pretrain_mm_mlp_adapter=None, + fsdp=None, + ): + self.config.mm_vision_tower = vision_tower + + image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + if not hasattr(self, "vision_tower"): + vision_tower = CLIPVisionModel.from_pretrained(vision_tower) + else: + vision_tower = self.vision_tower[0] + vision_tower.requires_grad_(False) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + else: + self.vision_tower = vision_tower + + vision_config = vision_tower.config + num_patches = (vision_config.image_size // vision_config.patch_size) ** 2 + + self.config.use_mm_proj = True + self.config.mm_hidden_size = vision_config.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + + if not hasattr(self, "mm_projector"): + self.mm_projector = nn.Linear( + vision_config.hidden_size, self.config.hidden_size + ) + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + self.mm_projector.load_state_dict( + {k.split(".")[-1]: v for k, v in mm_projector_weights.items()} + ) + + return dict( + image_processor=image_processor, + image_token_len=num_patches, + vision_config=vision_config, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + # HACK: replace back original embeddings for LLaVA pretraining + orig_embeds_params = getattr(self, "orig_embeds_params", None) + # if orig_embeds_params is not None: + # orig_embeds_params = orig_embeds_params[0] + # with torch.no_grad(): + # 
self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + vision_tower = self.get_vision_tower() + if ( + vision_tower is not None + and (input_ids.shape[1] != 1 or self.training) + and images is not None + ): + # TODO: this is a modified multimodal LLM -- Haotian Liu + with torch.no_grad(): + if type(images) is list: + # variable length images + image_features = [] + for image in images: + image_forward_out = vision_tower( + image.unsqueeze(0), output_hidden_states=True + ) + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_out.hidden_states[ + select_hidden_state_layer + ] + image_feature = select_hidden_state[:, 1:] + image_features.append(image_feature) + else: + image_forward_outs = vision_tower( + images.to(vision_tower.dtype), output_hidden_states=True + ) + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_outs.hidden_states[ + select_hidden_state_layer + ] + image_features = select_hidden_state[:, 1:].to(images.dtype) + if type(images) is list: + image_features = [ + self.mm_projector(image_feature)[0] + for image_feature in image_features + ] + else: + image_features = self.mm_projector(image_features) + dummy_image_features = torch.zeros( + 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + dummy_image_features = self.mm_projector(dummy_image_features) + + new_input_embeds = [] + cur_image_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = ( + cur_input_embeds + (0.0 * dummy_image_features).sum() + ) + new_input_embeds.append(cur_input_embeds) + cur_image_idx += 1 + continue + if vision_tower.config.use_im_start_end: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == vision_tower.config.im_start_token).sum() != ( + cur_input_ids == vision_tower.config.im_end_token + ).sum(): + raise ValueError( + "The number of image start tokens and image end tokens should be the same." + ) + image_start_tokens = torch.where( + cur_input_ids == vision_tower.config.im_start_token + )[0] + for image_start_token_pos in image_start_tokens: + cur_image_features = image_features[cur_image_idx].to( + device=cur_input_embeds.device + ) + num_patches = cur_image_features.shape[0] + if ( + cur_input_ids[image_start_token_pos + num_patches + 1] + != vision_tower.config.im_end_token + ): + raise ValueError( + "The image end token should follow the image start token." 
+ ) + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:image_start_token_pos].detach(), + cur_input_embeds[ + image_start_token_pos : image_start_token_pos + + 1 + ], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + + num_patches + + 1 : image_start_token_pos + + num_patches + + 2 + ], + cur_input_embeds[ + image_start_token_pos + num_patches + 2 : + ].detach(), + ), + dim=0, + ) + else: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[: image_start_token_pos + 1], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + num_patches + 1 : + ], + ), + dim=0, + ) + cur_image_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if ( + cur_input_ids == vision_tower.config.im_patch_token + ).sum() != num_patches: + raise ValueError( + "The number of image patch tokens should be the same as the number of image patches." + ) + masked_indices = torch.where( + cur_input_ids == vision_tower.config.im_patch_token + )[0] + mask_index_start = masked_indices[0] + if ( + masked_indices + != torch.arange( + mask_index_start, + mask_index_start + num_patches, + device=masked_indices.device, + dtype=masked_indices.dtype, + ) + ).any(): + raise ValueError( + "The image patch tokens should be consecutive." + ) + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start].detach(), + cur_image_features, + cur_input_embeds[ + mask_index_start + num_patches : + ].detach(), + ), + dim=0, + ) + else: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start], + cur_image_features, + cur_input_embeds[mask_index_start + num_patches :], + ), + dim=0, + ) + new_input_embeds.append(cur_new_input_embeds) + cur_image_idx += 1 + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return super(LlavaLlamaModel, self).forward( + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class EditMapper(nn.Module): + def __init__(self): + super().__init__() + + self.llm2hid = nn.Linear(4096, 512) + self.query = nn.Parameter(torch.randn(1, 77, 512)) + self.mapper = nn.Transformer( + batch_first=True, + norm_first=True, + d_model=512, + nhead=4, + num_encoder_layers=4, + num_decoder_layers=4, + dim_feedforward=2048, + dropout=0.0, + ) + self.hid2feat = nn.Linear(512, 768) + + def forward(self, llm, emb): + hid = self.llm2hid(llm + emb) + hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1)) + feat = self.hid2feat(hid) + + return feat + + +class LlavaLlamaForCausalLM(LlamaForCausalLM): + config_class = LlavaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = LlavaLlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.edit_head = EditMapper() + + self.scheduler, self.vae, self.unet = [ + diffusers.DDPMScheduler.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="scheduler" + ), + diffusers.AutoencoderKL.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="vae" + ), + diffusers.UNet2DConditionModel.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="unet" + ), + ] + self.vae.requires_grad_(False) + 
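+        # The block below widens the UNet input from 4 to 8 latent channels so the
+        # noisy target latent can be concatenated with the input-image latent
+        # (InstructPix2Pix-style conditioning): the replacement conv_in copies the
+        # pretrained weights into its first 4 input channels and leaves the extra 4
+        # zero-initialised, so the pretrained behaviour is preserved at the start of
+        # training.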
self.unet.register_to_config(in_channels=8) + with torch.no_grad(): + conv = torch.nn.Conv2d( + 8, + self.unet.conv_in.out_channels, + self.unet.conv_in.kernel_size, + self.unet.conv_in.stride, + self.unet.conv_in.padding, + ) + conv.weight.zero_() + conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) + self.unet.conv_in = conv + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_vision_tower(self): + model = self.get_model() + vision_tower = model.vision_tower + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + p2p_inp=None, + p2p_ans=None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + images=images, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model/pipeline parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if labels is not None: + llm = [] + for i in range(labels.shape[0]): + try: + p = labels[i].data.cpu().tolist().index(32003) - 1 + except: + p = len(labels[i]) - 9 + p = min(len(hidden_states[i]) - 9, p) + llm.append(hidden_states[i][p : p + 8].unsqueeze(0)) + llm = torch.cat(llm, dim=0) + hid_edit = self.edit_head( + llm, + self.model.embed_tokens.weight[-8:] + .unsqueeze(dim=0) + .repeat(labels.shape[0], 1, 1), + ) + + B, DROP = labels.shape[0], 0.05 + + hid_null = self.edit_head( + torch.zeros(B, 8, 4096, device=labels.device), + self.model.embed_tokens.weight[-8:] + .unsqueeze(dim=0) + .repeat(labels.shape[0], 1, 1), + ) + + with torch.no_grad(): + lat_ans, lat_inp = ( + self.vae.encode(p2p_ans).latent_dist.sample() + * self.vae.config.scaling_factor, + self.vae.encode(p2p_inp).latent_dist.mode(), + ) + lat_ans, lat_inp = [ + torch.from_numpy(lat_ans.data.cpu().float().numpy()).to( + lat_ans.device + ), + torch.from_numpy(lat_inp.data.cpu().float().numpy()).to( + lat_inp.device + ), + ] 
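+            # The numpy round trip above yields float32 copies of the VAE latents
+            # with no autograd history. What follows is standard noise-prediction
+            # training: sample Gaussian noise and random scheduler timesteps, noise
+            # the target latent with the DDPM scheduler, and randomly drop the edit
+            # embedding and/or the input latent (with probabilities controlled by
+            # DROP) so that classifier-free guidance can be applied at inference.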
+ + noise = torch.randn_like(lat_ans) + ts = torch.randint( + 0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device + ).long() + lat_noise = self.scheduler.add_noise(lat_ans, noise, ts) + + prob = torch.rand(B, device=lat_ans.device) + mask = (prob < (DROP * 2)).reshape(B, 1, 1) + hid_edit = torch.where(mask, hid_null, hid_edit) + mask = ( + 1.0 + - ( + (prob >= DROP).to(lat_inp.dtype) + * (prob < (DROP * 3)).to(lat_inp.dtype) + ) + ).reshape(B, 1, 1, 1) + lat_inp *= mask + + out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample + + loss_ce, loss_edit = loss, nn.functional.mse_loss( + out, noise, reduction="mean" + ) + if int(os.environ["LOCAL_RANK"]) == 0: + print("loss_ce:", loss_ce, "/", "loss_edit:", loss_edit) + loss = loss_ce + loss_edit * 0.5 + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "images": kwargs.get("images", None), + } + ) + return model_inputs + + def initialize_vision_tokenizer( + self, + mm_use_im_start_end, + tokenizer, + device, + tune_mm_mlp_adapter=False, + pretrain_mm_mlp_adapter=None, + ): + vision_config = self.get_vision_tower().config + vision_config.use_im_start_end = mm_use_im_start_end + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) + self.resize_token_embeddings(len(tokenizer)) + vision_config.im_start_token, vision_config.im_end_token = ( + tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) + ) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if tune_mm_mlp_adapter: + self.get_model().orig_embeds_params = [ + self.get_input_embeddings().weight.data.clone().to(device=device) + ] + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + 
input_embeddings[-num_new_tokens:] = embed_tokens_weight[ + -num_new_tokens: + ] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError( + f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}." + ) + + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] + + +AutoConfig.register(REGISTER_NAME, LlavaConfig) +AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) diff --git a/ml_mgie/ml_mgie/utils.py b/ml_mgie/ml_mgie/utils.py new file mode 100644 index 0000000..12eda69 --- /dev/null +++ b/ml_mgie/ml_mgie/utils.py @@ -0,0 +1,26 @@ +# TODO: add typing +def crop_resize(f, sz=512): + w, h = f.size + if w > h: + p = (w - h) // 2 + f = f.crop([p, 0, p + h, h]) + elif h > w: + p = (h - w) // 2 + f = f.crop([0, p, w, p + w]) + f = f.resize([sz, sz]) + return f + + +def remove_alter(s): # hack expressive instruction + if "ASSISTANT:" in s: + s = s[s.index("ASSISTANT:") + 10 :].strip() + if "" in s: + s = s[: s.index("")].strip() + if "alternative" in s.lower(): + s = s[: s.lower().index("alternative")] + if "[IMG0]" in s: + s = s[: s.index("[IMG0]")] + s = ".".join([s.strip() for s in s.split(".")[:2]]) + if s[-1] != ".": + s += "." + return s.strip() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..37345ae --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[tool.poetry] +name = "ml-mgie" +packages = [ + { include = "ml_mgie", from = "./ml_mgie" }, + { include = "llava", from = "./LLaVa" }, +] +version = "v0.0.0" +description = "" +authors = [ + "Tsu-Jui Fu", + "Wenze Hu", + "Xianzhi Du", + "William Yang Wang", + "Yinfei Yang", + "Zhe Gan", + "Paul Asquin", # only as package contributor +] + +[build-system] +requires = ["poetry-core", "setuptools", "setuptools-scm"] +build-backend = "poetry.core.masonry.api" + +[tool.setuptools_scm] + +[tool.poetry.dependencies] +python = ">=3.10,<3.11" +torch = "^2.2.0" +tqdm = "^4.66.1" +transformers = "^4.37.2" +diffusers = "^0.26.2" +sentencepiece = "^0.1.99" +protobuf = "^4.25.2" +accelerate = "^0.26.1" From 77b69b4da4d6eafde00afeab14612a8f47be835c Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Thu, 8 Feb 2024 12:43:02 +0100 Subject: [PATCH 02/23] add package readme --- ml_mgie/README.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 ml_mgie/README.md diff --git a/ml_mgie/README.md b/ml_mgie/README.md new file mode 100644 index 0000000..1f3bdfc --- /dev/null +++ b/ml_mgie/README.md @@ -0,0 +1,16 @@ +# ML-MGIE Packaging + +**Work In Progress**: package ml-mgie, simplify dependencies, make compatible with MPS, CUDA and CPU + +Packaging contributors: +- Paul Asquin + +## Installation +```bash +poetry install +``` + +## Demo +```bash +poetry run python ml_mgie/demo/inference.py +``` \ No newline at end of file From 770e8ad856480234764dc754a0ecbfad160e42c3 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Thu, 8 Feb 2024 23:23:04 +0100 Subject: [PATCH 03/23] preco and wip for mps --- mgie_llava.py | 398 ++++++++++++++----- mgie_train.py | 528 ++++++++++++++++---------- ml_mgie/demo/inference.py | 71 ++-- ml_mgie/ml_mgie/llava_conversation.py | 81 ++-- ml_mgie/ml_mgie/mgie.py | 76 ++-- ml_mgie/ml_mgie/mgie_llava.py | 34 +- 6 files changed, 786 insertions(+), 402 deletions(-) diff --git a/mgie_llava.py b/mgie_llava.py index 3d2ea7c..1fa65fc 100644 --- 
a/mgie_llava.py +++ b/mgie_llava.py @@ -4,20 +4,26 @@ # # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py +import os from typing import List, Optional, Tuple, Union +import diffusers import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn import CrossEntropyLoss - -from transformers import AutoConfig, AutoModelForCausalLM, \ - LlamaConfig, LlamaModel, LlamaForCausalLM, \ - CLIPVisionModel, CLIPImageProcessor - -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast - -import os, diffusers +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + CLIPImageProcessor, + CLIPVisionModel, + LlamaConfig, + LlamaForCausalLM, + LlamaModel, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" @@ -37,25 +43,32 @@ def __init__(self, config: LlamaConfig): if hasattr(config, "mm_vision_tower"): # HACK: for FSDP - self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)] + self.vision_tower = [ + CLIPVisionModel.from_pretrained(config.mm_vision_tower) + ] # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) if hasattr(config, "use_mm_proj"): self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) def get_vision_tower(self): - vision_tower = getattr(self, 'vision_tower', None) + vision_tower = getattr(self, "vision_tower", None) if type(vision_tower) is list: vision_tower = vision_tower[0] return vision_tower - def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, - pretrain_mm_mlp_adapter=None, fsdp=None): + def initialize_vision_modules( + self, + vision_tower, + mm_vision_select_layer, + pretrain_mm_mlp_adapter=None, + fsdp=None, + ): self.config.mm_vision_tower = vision_tower image_processor = CLIPImageProcessor.from_pretrained(vision_tower) - if not hasattr(self, 'vision_tower'): + if not hasattr(self, "vision_tower"): vision_tower = CLIPVisionModel.from_pretrained(vision_tower) else: vision_tower = self.vision_tower[0] @@ -73,17 +86,23 @@ def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, self.config.mm_hidden_size = vision_config.hidden_size self.config.mm_vision_select_layer = mm_vision_select_layer - if not hasattr(self, 'mm_projector'): - self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size) + if not hasattr(self, "mm_projector"): + self.mm_projector = nn.Linear( + vision_config.hidden_size, self.config.hidden_size + ) if pretrain_mm_mlp_adapter is not None: - mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') - self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + self.mm_projector.load_state_dict( + {k.split(".")[-1]: v for k, v in mm_projector_weights.items()} + ) return dict( image_processor=image_processor, image_token_len=num_patches, - vision_config=vision_config + vision_config=vision_config, ) def forward( @@ -98,9 +117,8 @@ def forward( images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - # HACK: replace back original embeddings for LLaVA pretraining - orig_embeds_params = getattr(self, 'orig_embeds_params', None) + orig_embeds_params = getattr(self, "orig_embeds_params", None) # 
if orig_embeds_params is not None: # orig_embeds_params = orig_embeds_params[0] # with torch.no_grad(): @@ -110,28 +128,49 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) vision_tower = self.get_vision_tower() - if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: + if ( + vision_tower is not None + and (input_ids.shape[1] != 1 or self.training) + and images is not None + ): # TODO: this is a modified multimodal LLM -- Haotian Liu with torch.no_grad(): if type(images) is list: # variable length images image_features = [] for image in images: - image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True) - select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) - select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer] + image_forward_out = vision_tower( + image.unsqueeze(0), output_hidden_states=True + ) + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_out.hidden_states[ + select_hidden_state_layer + ] image_feature = select_hidden_state[:, 1:] image_features.append(image_feature) else: - image_forward_outs = vision_tower(images.to(vision_tower.dtype), output_hidden_states=True) - select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) - select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer] + image_forward_outs = vision_tower( + images.to(vision_tower.dtype), output_hidden_states=True + ) + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_outs.hidden_states[ + select_hidden_state_layer + ] image_features = select_hidden_state[:, 1:].to(images.dtype) if type(images) is list: - image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features] + image_features = [ + self.mm_projector(image_feature)[0] + for image_feature in image_features + ] else: image_features = self.mm_projector(image_features) - dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_image_features = torch.zeros( + 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) dummy_image_features = self.mm_projector(dummy_image_features) new_input_embeds = [] @@ -139,69 +178,158 @@ def forward( for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: # multimodal LLM, but the current sample is not multimodal - cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum() + cur_input_embeds = ( + cur_input_embeds + (0.0 * dummy_image_features).sum() + ) new_input_embeds.append(cur_input_embeds) cur_image_idx += 1 continue if vision_tower.config.use_im_start_end: cur_image_features = image_features[cur_image_idx] num_patches = cur_image_features.shape[0] - if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum(): - raise ValueError("The number of image start tokens and image end tokens should be the same.") - image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0] + if (cur_input_ids == vision_tower.config.im_start_token).sum() != ( + cur_input_ids == vision_tower.config.im_end_token + ).sum(): + raise ValueError( + "The number of image start tokens and image end tokens should be the same." 
+ ) + image_start_tokens = torch.where( + cur_input_ids == vision_tower.config.im_start_token + )[0] for image_start_token_pos in image_start_tokens: - cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device) + cur_image_features = image_features[cur_image_idx].to( + device=cur_input_embeds.device + ) num_patches = cur_image_features.shape[0] - if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token: - raise ValueError("The image end token should follow the image start token.") + if ( + cur_input_ids[image_start_token_pos + num_patches + 1] + != vision_tower.config.im_end_token + ): + raise ValueError( + "The image end token should follow the image start token." + ) if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:image_start_token_pos].detach(), + cur_input_embeds[ + image_start_token_pos : image_start_token_pos + + 1 + ], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + + num_patches + + 1 : image_start_token_pos + + num_patches + + 2 + ], + cur_input_embeds[ + image_start_token_pos + num_patches + 2 : + ].detach(), + ), + dim=0, + ) else: - cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[: image_start_token_pos + 1], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + num_patches + 1 : + ], + ), + dim=0, + ) cur_image_idx += 1 new_input_embeds.append(cur_new_input_embeds) else: cur_image_features = image_features[cur_image_idx] num_patches = cur_image_features.shape[0] - if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches: - raise ValueError("The number of image patch tokens should be the same as the number of image patches.") - masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0] + if ( + cur_input_ids == vision_tower.config.im_patch_token + ).sum() != num_patches: + raise ValueError( + "The number of image patch tokens should be the same as the number of image patches." + ) + masked_indices = torch.where( + cur_input_ids == vision_tower.config.im_patch_token + )[0] mask_index_start = masked_indices[0] - if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any(): - raise ValueError("The image patch tokens should be consecutive.") + if ( + masked_indices + != torch.arange( + mask_index_start, + mask_index_start + num_patches, + device=masked_indices.device, + dtype=masked_indices.dtype, + ) + ).any(): + raise ValueError( + "The image patch tokens should be consecutive." 
+ ) if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start].detach(), + cur_image_features, + cur_input_embeds[ + mask_index_start + num_patches : + ].detach(), + ), + dim=0, + ) else: - cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start], + cur_image_features, + cur_input_embeds[mask_index_start + num_patches :], + ), + dim=0, + ) new_input_embeds.append(cur_new_input_embeds) cur_image_idx += 1 inputs_embeds = torch.stack(new_input_embeds, dim=0) return super(LlavaLlamaModel, self).forward( - input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) + class EditMapper(nn.Module): def __init__(self): super().__init__() self.llm2hid = nn.Linear(4096, 512) self.query = nn.Parameter(torch.randn(1, 77, 512)) - self.mapper = nn.Transformer(batch_first=True, norm_first=True, - d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4, - dim_feedforward=2048, dropout=0.0) + self.mapper = nn.Transformer( + batch_first=True, + norm_first=True, + d_model=512, + nhead=4, + num_encoder_layers=4, + num_decoder_layers=4, + dim_feedforward=2048, + dropout=0.0, + ) self.hid2feat = nn.Linear(512, 768) def forward(self, llm, emb): - hid = self.llm2hid(llm+emb) + hid = self.llm2hid(llm + emb) hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1)) feat = self.hid2feat(hid) return feat + class LlavaLlamaForCausalLM(LlamaForCausalLM): config_class = LlavaConfig @@ -213,13 +341,27 @@ def __init__(self, config): self.edit_head = EditMapper() - self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'), - diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'), - diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')] + self.scheduler, self.vae, self.unet = [ + diffusers.DDPMScheduler.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="scheduler" + ), + diffusers.AutoencoderKL.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="vae" + ), + diffusers.UNet2DConditionModel.from_pretrained( + "runwayml/stable-diffusion-v1-5", subfolder="unet" + ), + ] self.vae.requires_grad_(False) self.unet.register_to_config(in_channels=8) with torch.no_grad(): - conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) + conv = torch.nn.Conv2d( + 8, + self.unet.conv_in.out_channels, + self.unet.conv_in.kernel_size, + self.unet.conv_in.stride, + self.unet.conv_in.padding, + ) conv.weight.zero_() conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) self.unet.conv_in = conv @@ -252,13 +394,22 @@ def forward( output_hidden_states: 
Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, - p2p_inp=None, p2p_ans=None + p2p_inp=None, + p2p_ans=None, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -270,7 +421,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - images=images + images=images, ) hidden_states = outputs[0] @@ -292,38 +443,70 @@ def forward( if labels is not None: llm = [] for i in range(labels.shape[0]): - try: p = labels[i].data.cpu().tolist().index(32003)-1 - except: p = len(labels[i])-9 - p = min(len(hidden_states[i])-9, p) - llm.append(hidden_states[i][p:p+8].unsqueeze(0)) + try: + p = labels[i].data.cpu().tolist().index(32003) - 1 + except: + p = len(labels[i]) - 9 + p = min(len(hidden_states[i]) - 9, p) + llm.append(hidden_states[i][p : p + 8].unsqueeze(0)) llm = torch.cat(llm, dim=0) - hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1)) + hid_edit = self.edit_head( + llm, + self.model.embed_tokens.weight[-8:] + .unsqueeze(dim=0) + .repeat(labels.shape[0], 1, 1), + ) B, DROP = labels.shape[0], 0.05 - hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device), - self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1)) + hid_null = self.edit_head( + torch.zeros(B, 8, 4096, device=labels.device), + self.model.embed_tokens.weight[-8:] + .unsqueeze(dim=0) + .repeat(labels.shape[0], 1, 1), + ) with torch.no_grad(): - lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode() - lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device), - torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)] + lat_ans, lat_inp = ( + self.vae.encode(p2p_ans).latent_dist.sample() + * self.vae.config.scaling_factor, + self.vae.encode(p2p_inp).latent_dist.mode(), + ) + lat_ans, lat_inp = [ + torch.from_numpy(lat_ans.data.cpu().float().numpy()).to( + lat_ans.device + ), + torch.from_numpy(lat_inp.data.cpu().float().numpy()).to( + lat_inp.device + ), + ] noise = torch.randn_like(lat_ans) - ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long() + ts = torch.randint( + 0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device + ).long() lat_noise = self.scheduler.add_noise(lat_ans, noise, ts) prob = torch.rand(B, device=lat_ans.device) - mask = (prob<(DROP*2)).reshape(B, 1, 1) + mask = (prob < (DROP * 2)).reshape(B, 1, 1) hid_edit = torch.where(mask, hid_null, hid_edit) - mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1) + mask = ( + 1.0 + - ( + (prob >= 
DROP).to(lat_inp.dtype) + * (prob < (DROP * 3)).to(lat_inp.dtype) + ) + ).reshape(B, 1, 1, 1) lat_inp *= mask out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample - loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean') - if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit) - loss = loss_ce+loss_edit*0.5 + loss_ce, loss_edit = loss, nn.functional.mse_loss( + out, noise, reduction="mean" + ) + if int(os.environ["LOCAL_RANK"]) == 0: + print("loss_ce:", loss_ce, "/", "loss_edit:", loss_edit) + loss = loss_ce + loss_edit * 0.5 if not return_dict: output = (logits,) + outputs[1:] @@ -338,7 +521,12 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, ): if past_key_values: input_ids = input_ids[:, -1:] @@ -359,49 +547,75 @@ def prepare_inputs_for_generation( ) return model_inputs - def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device, - tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None): + def initialize_vision_tokenizer( + self, + mm_use_im_start_end, + tokenizer, + device, + tune_mm_mlp_adapter=False, + pretrain_mm_mlp_adapter=None, + ): vision_config = self.get_vision_tower().config vision_config.use_im_start_end = mm_use_im_start_end tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) if mm_use_im_start_end: - num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + num_new_tokens = tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) self.resize_token_embeddings(len(tokenizer)) - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if tune_mm_mlp_adapter: - self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] + self.get_model().orig_embeds_params = [ + self.get_input_embeddings().weight.data.clone().to(device=device) + ] for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False if pretrain_mm_mlp_adapter: - mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') - embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] assert num_new_tokens == 2 if input_embeddings.shape == embed_tokens_weight.shape: - input_embeddings[-num_new_tokens:] = 
embed_tokens_weight[-num_new_tokens:] + input_embeddings[-num_new_tokens:] = embed_tokens_weight[ + -num_new_tokens: + ] elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: - raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + raise ValueError( + f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}." + ) + + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] AutoConfig.register("llava", LlavaConfig) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) diff --git a/mgie_train.py b/mgie_train.py index 46edf2d..f0e6948 100644 --- a/mgie_train.py +++ b/mgie_train.py @@ -4,25 +4,21 @@ # # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/train/train.py -import os import copy -from dataclasses import dataclass, field import json import logging +import os import pathlib -from typing import Dict, Optional, Sequence, List +from dataclasses import dataclass, field +from typing import Dict, Optional, Sequence import torch - import transformers -from torch.utils.data import Dataset -from llava.train.llava_trainer import LLaVATrainer - from llava import conversation as conversation_lib from llava.model import * - +from llava.train.llava_trainer import LLaVATrainer from PIL import Image -import torch.nn as nn +from torch.utils.data import Dataset # TODO: import and use code from ../data/dataset.py @@ -36,22 +32,34 @@ DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" -import io, base64, pickle, random -from tqdm import tqdm +import base64 +import io +import pickle +import random + import numpy as np +from tqdm import tqdm + + +def b2f(b): + return Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB") + -def b2f(b): return Image.open(io.BytesIO(base64.b64decode(b))).convert('RGB') def resize(f): w, h = f.size - if w>h: - p = (w-h)//2 - f = f.crop([p, 0, p+h, h]) - elif h>w: - p = (h-w)//2 - f = f.crop([0, p, w, p+w]) + if w > h: + p = (w - h) // 2 + f = f.crop([p, 0, p + h, h]) + elif h > w: + p = (h - w) // 2 + f = f.crop([0, p, w, p + w]) f = f.resize([512, 512]) return f -def img2npy(f): return (2.0*np.array(f)/255.0-1.0).transpose((2, 0, 1)).astype(np.float32) + + +def img2npy(f): + return (2.0 * np.array(f) / 255.0 - 1.0).transpose((2, 0, 1)).astype(np.float32) + @dataclass class ModelArguments: @@ -60,21 +68,24 @@ class ModelArguments: freeze_backbone: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) vision_tower: Optional[str] = field(default=None) - mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + mm_vision_select_layer: Optional[int] = field( + default=-1 + ) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) mm_use_im_start_end: bool = field(default=False) @dataclass class DataArguments: - data_path: str = field(default=None, - metadata={"help": "Path to the training data."}) + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) lazy_preprocess: bool = False is_multimodal: bool = False sep_image_conv_front: bool = False image_token_len: int = 0 
image_folder: Optional[str] = field(default=None) - image_aspect_ratio: str = 'square' + image_aspect_ratio: str = "square" @dataclass @@ -87,22 +98,22 @@ class TrainingArguments(transformers.TrainingArguments): model_max_length: int = field( default=512, metadata={ - "help": - "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) double_quant: bool = field( default=True, - metadata={"help": "Compress the quantization statistics through double quantization."} + metadata={ + "help": "Compress the quantization statistics through double quantization." + }, ) quant_type: str = field( default="nf4", - metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} - ) - bits: int = field( - default=16, - metadata={"help": "How many bits to use."} + metadata={ + "help": "Quantization data type to use. Should be one of `fp4` or `nf4`." + }, ) + bits: int = field(default=16, metadata={"help": "How many bits to use."}) lora_enable: bool = False lora_r: int = 64 lora_alpha: int = 16 @@ -114,10 +125,13 @@ class TrainingArguments(transformers.TrainingArguments): def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: - logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + logging.warning( + f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}" + ) with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: @@ -155,7 +169,9 @@ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): to_return = {k: t for k, t in named_params if "lora_" not in k} if require_grad_only: to_return = {k: t for k, t in to_return.items() if t.requires_grad} - to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + to_return = { + k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items() + } return to_return @@ -164,17 +180,15 @@ def find_all_linear_names(model): lora_module_names = set() for name, module in model.named_modules(): if isinstance(module, cls): - names = name.split('.') + names = name.split(".") lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - - if 'lm_head' in lora_module_names: # needed for 16-bit - lora_module_names.remove('lm_head') + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") return list(lora_module_names) -def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, - output_dir: str): +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): """Collects the state dict and dump to disk.""" if trainer.deepspeed: torch.cuda.synchronize() @@ -183,10 +197,7 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, state_dict = trainer.model.state_dict() if trainer.args.should_save: - cpu_state_dict = { - key: value.cpu() - for key, value in state_dict.items() - } + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa @@ -208,16 +219,19 @@ def smart_tokenizer_and_embedding_resize( output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = 
input_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg -def _tokenize_fn(strings: Sequence[str], - tokenizer: transformers.PreTrainedTokenizer) -> Dict: +def _tokenize_fn( + strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer +) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( @@ -226,11 +240,10 @@ def _tokenize_fn(strings: Sequence[str], padding="longest", max_length=tokenizer.model_max_length, truncation=True, - ) for text in strings - ] - input_ids = labels = [ - tokenized.input_ids[0] for tokenized in tokenized_list + ) + for text in strings ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list @@ -250,7 +263,7 @@ def _mask_targets(target, tokenized_lens, speakers): target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": - target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len @@ -266,9 +279,10 @@ def _add_speaker_and_signal(header, source, get_conversation=True): elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: - from_str = 'unknown' - sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + - sentence["value"] + END_SIGNAL) + from_str = "unknown" + sentence["value"] = ( + BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL + ) if get_conversation: conversation += sentence["value"] conversation += BEGIN_SIGNAL @@ -280,22 +294,34 @@ def preprocess_multimodal( multimodal_cfg: dict, cur_token_len: int, ) -> Dict: - is_multimodal = multimodal_cfg['is_multimodal'] + is_multimodal = multimodal_cfg["is_multimodal"] # image_token_len = multimodal_cfg['image_token_len'] image_token_len = cur_token_len if not is_multimodal: return sources for source in sources: - if multimodal_cfg['sep_image_conv_front']: - assert DEFAULT_IMAGE_TOKEN in source[0]['value'] - source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() - source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value'] + if multimodal_cfg["sep_image_conv_front"]: + assert DEFAULT_IMAGE_TOKEN in source[0]["value"] + source[0]["value"] = ( + source[0]["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + ) + source[0]["value"] = ( + DEFAULT_IMAGE_TOKEN + + conversation_lib.default_conversation.sep + + conversation_lib.default_conversation.roles[0] + + ": " + + source[0]["value"] + ) for sentence in source: replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len - if multimodal_cfg['use_im_start_end']: - replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN - sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + if multimodal_cfg["use_im_start_end"]: + replace_token = ( + DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + ) + sentence["value"] = sentence["value"].replace( + DEFAULT_IMAGE_TOKEN, replace_token + ) return sources @@ -370,6 +396,7 @@ def preprocess_v1( labels=targets, ) + def 
preprocess_mpt( sources, tokenizer: transformers.PreTrainedTokenizer, @@ -408,9 +435,11 @@ def preprocess_mpt( total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) - re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt for conv_idx in range(3, len(rounds), 2): - re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt + re_rounds.append( + conv.sep.join(rounds[conv_idx : conv_idx + 2]) + ) # user + gpt cur_len = 0 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(re_rounds): @@ -421,7 +450,9 @@ def preprocess_mpt( if len(parts) != 2: break parts[0] += sep - round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids) + round_len = len(tokenizer(rou).input_ids) + len( + tokenizer(conv.sep).input_ids + ) instruction_len = len(tokenizer(parts[0]).input_ids) target[cur_len : cur_len + instruction_len] = IGNORE_INDEX @@ -468,8 +499,9 @@ def preprocess( input_ids = conversations_tokenized["input_ids"] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): - tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], - tokenizer)["input_ids_lens"] + tokenized_lens = _tokenize_fn( + [header] + [s["value"] for s in source], tokenizer + )["input_ids_lens"] speakers = [sentence["from"] for sentence in source] _mask_targets(target, tokenized_lens, speakers) @@ -479,8 +511,7 @@ def preprocess( class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, data_path: str, - tokenizer: transformers.PreTrainedTokenizer): + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): super(SupervisedDataset, self).__init__() logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) @@ -500,17 +531,21 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]: class LazySupervisedDataset(Dataset): - - def __init__(self, data_path: str, - tokenizer: transformers.PreTrainedTokenizer, - multimodal_cfg: dict): + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + multimodal_cfg: dict, + ): super(LazySupervisedDataset, self).__init__() self.tokenizer, self.multimodal_cfg = tokenizer, multimodal_cfg - self.pkl, self.prompt = pickle.load(open('./_data/ipr2pr.pkl', 'rb'))['task'], json.load(open('./_data/ipr2pr_expressive.json', 'r')) + self.pkl, self.prompt = pickle.load(open("./_data/ipr2pr.pkl", "rb"))[ + "task" + ], json.load(open("./_data/ipr2pr_expressive.json", "r")) random.shuffle(self.pkl) - print('--pkl: %d--'%(len(self.pkl))) + print("--pkl: %d--" % (len(self.pkl))) def __len__(self): return len(self.pkl) @@ -518,27 +553,41 @@ def __len__(self): def __getitem__(self, i) -> Dict[str, torch.Tensor]: item = self.pkl[i][0] - tsv = open('./_data/ipr2pr.tsv', 'r') - tsv.seek(item['lineidx']) - b = tsv.readline().strip().split('\t') + tsv = open("./_data/ipr2pr.tsv", "r") + tsv.seek(item["lineidx"]) + b = tsv.readline().strip().split("\t") image = resize(b2f(b[0])) - processor = self.multimodal_cfg['image_processor'] - image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + processor = self.multimodal_cfg["image_processor"] + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] - cur_token_len = (image.shape[1]//14)*(image.shape[2]//14) - query = "what will this image be like if '%s'\n%s"%(item['instruction'], DEFAULT_IMAGE_TOKEN) - ans = '%s [IMG0] [IMG1] [IMG2] 
[IMG3] [IMG4] [IMG5] [IMG6] [IMG7]'%(self.prompt[item['input']]['expressive']) - sources = preprocess_multimodal(copy.deepcopy([[{'from': 'human', 'value': query}, {'from': 'gpt', 'value': ans}]]), - self.multimodal_cfg, cur_token_len) + cur_token_len = (image.shape[1] // 14) * (image.shape[2] // 14) + query = "what will this image be like if '%s'\n%s" % ( + item["instruction"], + DEFAULT_IMAGE_TOKEN, + ) + ans = "%s [IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]" % ( + self.prompt[item["input"]]["expressive"] + ) + sources = preprocess_multimodal( + copy.deepcopy( + [[{"from": "human", "value": query}, {"from": "gpt", "value": ans}]] + ), + self.multimodal_cfg, + cur_token_len, + ) data_dict = preprocess(sources, self.tokenizer) - if isinstance(i, int): data_dict = dict(input_ids=data_dict['input_ids'][0], - labels=data_dict['labels'][0]) - data_dict['image'] = image + if isinstance(i, int): + data_dict = dict( + input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0] + ) + data_dict["image"] = image - p2p_inp, p2p_ans = img2npy(resize(b2f(b[0])).resize([256, 256])), img2npy(resize(b2f(b[1])).resize([256, 256])) - data_dict['p2p_inp'], data_dict['p2p_ans'] = p2p_inp, p2p_ans + p2p_inp, p2p_ans = img2npy(resize(b2f(b[0])).resize([256, 256])), img2npy( + resize(b2f(b[1])).resize([256, 256]) + ) + data_dict["p2p_inp"], data_dict["p2p_ans"] = p2p_inp, p2p_ans return data_dict @@ -550,97 +599,119 @@ class DataCollatorForSupervisedDataset(object): tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple([instance[key] for instance in instances] - for key in ("input_ids", "labels")) + input_ids, labels = tuple( + [instance[key] for instance in instances] for key in ("input_ids", "labels") + ) input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id) - labels = torch.nn.utils.rnn.pad_sequence(labels, - batch_first=True, - padding_value=IGNORE_INDEX) + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) - if 'image' in instances[0]: - images = [instance['image'] for instance in instances] + if "image" in instances[0]: + images = [instance["image"] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): - batch['images'] = torch.stack(images) + batch["images"] = torch.stack(images) else: - batch['images'] = images - - batch['p2p_inp'], batch['p2p_ans'] = [torch.cat([torch.from_numpy(d['p2p_inp']).unsqueeze(dim=0) for d in instances], dim=0), - torch.cat([torch.from_numpy(d['p2p_ans']).unsqueeze(dim=0) for d in instances], dim=0)] + batch["images"] = images + + batch["p2p_inp"], batch["p2p_ans"] = [ + torch.cat( + [torch.from_numpy(d["p2p_inp"]).unsqueeze(dim=0) for d in instances], + dim=0, + ), + torch.cat( + [torch.from_numpy(d["p2p_ans"]).unsqueeze(dim=0) for d in instances], + dim=0, + ), + ] return batch -def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, - data_args) -> Dict: +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: """Make dataset and collator for supervised fine-tuning.""" - dataset_cls = (LazySupervisedDataset - if data_args.lazy_preprocess else 
SupervisedDataset) - train_dataset = dataset_cls(tokenizer=tokenizer, - data_path=data_args.data_path, - multimodal_cfg=dict( - is_multimodal=data_args.is_multimodal, - sep_image_conv_front=data_args.sep_image_conv_front, - image_token_len=data_args.image_token_len, - image_folder=data_args.image_folder, - image_aspect_ratio=data_args.image_aspect_ratio, - use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False), - image_processor=getattr(data_args, 'image_processor', None))) + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + train_dataset = dataset_cls( + tokenizer=tokenizer, + data_path=data_args.data_path, + multimodal_cfg=dict( + is_multimodal=data_args.is_multimodal, + sep_image_conv_front=data_args.sep_image_conv_front, + image_token_len=data_args.image_token_len, + image_folder=data_args.image_folder, + image_aspect_ratio=data_args.image_aspect_ratio, + use_im_start_end=getattr(data_args, "mm_use_im_start_end", False), + image_processor=getattr(data_args, "image_processor", None), + ), + ) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - return dict(train_dataset=train_dataset, - eval_dataset=None, - data_collator=data_collator) + return dict( + train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator + ) def train(): - parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) model_args, data_args, training_args = parser.parse_args_into_dataclasses() - compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: - from transformers import BitsAndBytesConfig from peft import prepare_model_for_int8_training - bnb_model_from_pretrained_args.update(dict( - device_map={"": training_args.device}, - load_in_4bit=training_args.bits == 4, - load_in_8bit=training_args.bits == 8, - quantization_config=BitsAndBytesConfig( + from transformers import BitsAndBytesConfig + + bnb_model_from_pretrained_args.update( + dict( + device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=compute_dtype, - bnb_4bit_use_double_quant=training_args.double_quant, - bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'} + ), ) - )) + ) if model_args.vision_tower is not None: - if 'mpt' in model_args.model_name_or_path: + if "mpt" in model_args.model_name_or_path: model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, - **bnb_model_from_pretrained_args + **bnb_model_from_pretrained_args, ) else: model = LlavaLlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, - **bnb_model_from_pretrained_args + **bnb_model_from_pretrained_args, ) else: model = 
transformers.LlamaForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, - **bnb_model_from_pretrained_args + **bnb_model_from_pretrained_args, ) model.config.use_cache = False @@ -648,19 +719,28 @@ def train(): model.model.requires_grad_(False) if training_args.bits in [4, 8]: - model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) - model = prepare_model_for_int8_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + model.config.torch_dtype = ( + torch.float32 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + model = prepare_model_for_int8_training( + model, use_gradient_checkpointing=training_args.gradient_checkpointing + ) if training_args.gradient_checkpointing and model_args.vision_tower is None: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: + def make_inputs_require_grad(module, input, output): output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if training_args.lora_enable: from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, @@ -677,12 +757,12 @@ def make_inputs_require_grad(module, input, output): logging.warning("Adding LoRA adapters...") model = get_peft_model(model, lora_config) - if 'mpt' in model_args.model_name_or_path: + if "mpt" in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, - padding_side="right" + padding_side="right", ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -701,33 +781,41 @@ def make_inputs_require_grad(module, input, output): model=model, ) if "llama" in model_args.model_name_or_path: - tokenizer.add_special_tokens({ - "eos_token": DEFAULT_EOS_TOKEN, - "bos_token": DEFAULT_BOS_TOKEN, - "unk_token": DEFAULT_UNK_TOKEN, - }) + tokenizer.add_special_tokens( + { + "eos_token": DEFAULT_EOS_TOKEN, + "bos_token": DEFAULT_BOS_TOKEN, + "unk_token": DEFAULT_UNK_TOKEN, + } + ) else: tokenizer.pad_token = tokenizer.unk_token if "mpt" in model_args.model_name_or_path: - conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"] + conversation_lib.default_conversation = conversation_lib.conv_templates[ + "mpt" + ] else: - conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"] + conversation_lib.default_conversation = conversation_lib.conv_templates[ + "vicuna_v1_1" + ] if model_args.vision_tower is not None: model_vision_dict = model.get_model().initialize_vision_modules( vision_tower=model_args.vision_tower, mm_vision_select_layer=model_args.mm_vision_select_layer, pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter, - fsdp=training_args.fsdp + fsdp=training_args.fsdp, ) model.get_vision_tower().to(dtype=torch.float16, device=training_args.device) - vision_config = model_vision_dict['vision_config'] + vision_config = model_vision_dict["vision_config"] - data_args.image_token_len = model_vision_dict['image_token_len'] - data_args.image_processor = model_vision_dict['image_processor'] + data_args.image_token_len = model_vision_dict["image_token_len"] + data_args.image_processor = model_vision_dict["image_processor"] data_args.is_multimodal = True - model.config.tune_mm_mlp_adapter = 
training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + model.config.tune_mm_mlp_adapter = ( + training_args.tune_mm_mlp_adapter + ) = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): @@ -739,74 +827,124 @@ def make_inputs_require_grad(module, input, output): p.requires_grad = False if training_args.bits in [4, 8]: - model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + model.get_model().mm_projector.to( + dtype=compute_dtype, device=training_args.device + ) - model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end - vision_config.use_im_start_end = training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_start_end = ( + data_args.mm_use_im_start_end + ) = model_args.mm_use_im_start_end + vision_config.use_im_start_end = ( + training_args.use_im_start_end + ) = model_args.mm_use_im_start_end model.config.sep_image_conv_front = data_args.sep_image_conv_front - model.initialize_vision_tokenizer(mm_use_im_start_end=model_args.mm_use_im_start_end, tokenizer=tokenizer, device=training_args.device, - tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter) + model.initialize_vision_tokenizer( + mm_use_im_start_end=model_args.mm_use_im_start_end, + tokenizer=tokenizer, + device=training_args.device, + tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, + pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter, + ) params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] if len(params_no_grad) > 0: if training_args.fsdp is not None and len(training_args.fsdp) > 0: if len(params_no_grad) < 10: - print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad)) + print( + "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}".format( + len(params_no_grad), params_no_grad + ) + ) else: - print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10]))) - print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.") - print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") + print( + "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)".format( + len(params_no_grad), ", ".join(params_no_grad[:10]) + ) + ) + print( + "[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental." + ) + print( + "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. 
See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining" + ) + + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, + ) - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP def patch_FSDP_use_orig_params(func): def wrap_func(*args, **kwargs): - use_orig_params = kwargs.pop('use_orig_params', True) + use_orig_params = kwargs.pop("use_orig_params", True) return func(*args, **kwargs, use_orig_params=use_orig_params) + return wrap_func FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) if training_args.bits in [4, 8]: from peft.tuners.lora import LoraLayer + for name, module in model.named_modules(): if isinstance(module, LoraLayer): if training_args.bf16: module = module.to(torch.bfloat16) - if 'norm' in name: + if "norm" in name: module = module.to(torch.float32) - if 'lm_head' in name or 'embed_tokens' in name: - if hasattr(module, 'weight'): + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): if training_args.bf16 and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) # start for MGIE - os.makedirs('_log', exist_ok=True) + os.makedirs("_log", exist_ok=True) pt = {} - for i in tqdm(range(2)): pt.update(torch.load('./_ckpt/LLaVA-7B-v1/pytorch_model-0000%d-of-00002.bin'%(i+1), map_location='cpu')) + for i in tqdm(range(2)): + pt.update( + torch.load( + "./_ckpt/LLaVA-7B-v1/pytorch_model-0000%d-of-00002.bin" % (i + 1), + map_location="cpu", + ) + ) miss, unexp = model.load_state_dict(pt, strict=False) - print('miss:', miss), print('unexp:', unexp) - - tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True) + print("miss:", miss), print("unexp:", unexp) + + tokenizer.add_tokens( + [ + "[IMG0]", + "[IMG1]", + "[IMG2]", + "[IMG3]", + "[IMG4]", + "[IMG5]", + "[IMG6]", + "[IMG7]", + ], + special_tokens=True, + ) model.resize_token_embeddings(len(tokenizer)) - print(tokenizer), json.dump(tokenizer.get_vocab(), open('_log/vocabs.json', 'w'), indent=2) + print(tokenizer), json.dump( + tokenizer.get_vocab(), open("_log/vocabs.json", "w"), indent=2 + ) for n, p in model.named_parameters(): - if 'embed_tokens' in n or 'lm_head' in n or 'edit_head' in n or 'unet' in n: p.requires_grad = True - else: p.requires_grad = False - with open('_log/parameters.txt', 'w') as F: - for n, p in model.named_parameters(): F.write('%s %s %s\n'%(n, str(p.shape), str(p.requires_grad))) - - with open('_log/args_train.txt', 'w') as F: - for key in vars(training_args): F.write('%s: %s\n'%(str(key), str(vars(training_args)[key]))) + if "embed_tokens" in n or "lm_head" in n or "edit_head" in n or "unet" in n: + p.requires_grad = True + else: + p.requires_grad = False + with open("_log/parameters.txt", "w") as F: + for n, p in model.named_parameters(): + F.write("%s %s %s\n" % (n, str(p.shape), str(p.requires_grad))) + + with open("_log/args_train.txt", "w") as F: + for key in vars(training_args): + F.write("%s: %s\n" % (str(key), str(vars(training_args)[key]))) # end for MGIE - data_module = make_supervised_data_module(tokenizer=tokenizer, - data_args=data_args) - trainer = LLaVATrainer(model=model, - tokenizer=tokenizer, - args=training_args, - **data_module) + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = LLaVATrainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) if 
list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) @@ -824,10 +962,14 @@ def wrap_func(*args, **kwargs): if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) - torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + torch.save( + non_lora_state_dict, + os.path.join(training_args.output_dir, "non_lora_trainables.bin"), + ) else: - safe_save_model_for_hf_trainer(trainer=trainer, - output_dir=training_args.output_dir) + safe_save_model_for_hf_trainer( + trainer=trainer, output_dir=training_args.output_dir + ) if __name__ == "__main__": diff --git a/ml_mgie/demo/inference.py b/ml_mgie/demo/inference.py index 5ea4860..fb5e7b5 100644 --- a/ml_mgie/demo/inference.py +++ b/ml_mgie/demo/inference.py @@ -1,21 +1,28 @@ import os -import torch -import tqdm -from PIL import Image +import shutil +from pathlib import Path +import torch +from ml_mgie.llava_conversation import conv_templates from ml_mgie.mgie import ( - MGIE, - MGIEParams, + DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, - DEFAULT_IM_END_TOKEN, + MGIE, + MGIEParams, ) - -from ml_mgie.utils import remove_alter, crop_resize -from ml_mgie.llava_conversation import conv_templates +from ml_mgie.utils import crop_resize, remove_alter +from PIL import Image +from tqdm import tqdm SEED = 13331 -mgie = MGIE(MGIEParams(device="cuda")) +CFG_TXT = 7.5 +CFG_IMG = 1.5 +params = MGIEParams() +mgie = MGIE(params=params) +input_path = Path("_input") +output_path = Path("_output") +os.makedirs(output_path, exist_ok=True) ins = [ "make the frame red", @@ -40,28 +47,32 @@ "let the floor be made of wood", ] for i in tqdm(range(len(ins))): - img, txt = Image.open("_input/%d.jpg" % (i)).convert("RGB"), ins[i] + image_input_path = input_path / f"{i}.jpg" + img = crop_resize(Image.open(image_input_path).convert("RGB")) + instruction = ins[i] img = mgie.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][0] - txt = "what will this image be like if '%s'" % (txt) - txt = ( - txt + prompt = f"what will this image be like if {instruction}" + prompt = ( + prompt + "\n" + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * mgie.image_token_len + DEFAULT_IM_END_TOKEN ) conv = conv_templates["vicuna_v1_1"].copy() - conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None) - txt = conv.get_prompt() - txt = mgie.tokenizer(txt) - txt, mask = torch.as_tensor(txt["input_ids"]), T.as_tensor(txt["attention_mask"]) + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + prompt_tokenized = mgie.tokenizer(prompt) + prompt_tensor_ids = torch.as_tensor(prompt_tokenized["input_ids"]) + mask = torch.as_tensor(prompt_tokenized["attention_mask"]) with torch.inference_mode(): out = mgie.model.generate( - txt.unsqueeze(dim=0).cuda(), - images=img.half().unsqueeze(dim=0).cuda(), - attention_mask=mask.unsqueeze(dim=0).cuda(), + prompt_tensor_ids.unsqueeze(dim=0).to(params.device), + images=img.half().unsqueeze(dim=0).to(params.device), + attention_mask=mask.unsqueeze(dim=0).to(params.device), do_sample=False, max_new_tokens=96, num_beams=1, @@ -79,14 +90,16 @@ out = remove_alter(mgie.tokenizer.decode(out)) emb = mgie.model.edit_head(hid.unsqueeze(dim=0), mgie.emb) - res = mgie.pipe( - image=Image.open("_input/%d.jpg" % 
(i)).convert("RGB"), + res: Image.Image = mgie.pipe( + image=Image.open(image_input_path).convert("RGB"), prompt_embeds=emb, negative_prompt_embeds=mgie.null, - generator=torch.Generator(device="cuda").manual_seed(SEED), + generator=torch.Generator(device=params.device).manual_seed(SEED), + guidance_scale=CFG_TXT, + image_guidance_scale=CFG_IMG, ).images[0] - - input = Image.open("_input/%d.jpg" % (i)).convert("RGB") - os.makedirs("_output", exist_ok=True) - input.save("_output/in-%d.jpg" % (i)) - Image.fromarray(res).save("_output/out-%d.jpg" % (i)) + # Save results before/after + print(f"Instruction: {instruction}") + print(f"Output: {out}") + shutil.copy(image_input_path, output_path / f"{i}-in.jpg") + res.save(output_path / f"{i}-out.jpg") diff --git a/ml_mgie/ml_mgie/llava_conversation.py b/ml_mgie/ml_mgie/llava_conversation.py index 05198fc..c7d3216 100644 --- a/ml_mgie/ml_mgie/llava_conversation.py +++ b/ml_mgie/ml_mgie/llava_conversation.py @@ -1,10 +1,11 @@ import dataclasses -from enum import auto, Enum -from typing import List, Tuple +from enum import Enum, auto +from typing import List class SeparatorStyle(Enum): """Different separator style.""" + SINGLE = auto() TWO = auto() MPT = auto() @@ -13,6 +14,7 @@ class SeparatorStyle(Enum): @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" + system: str roles: List[str] messages: List[List[str]] @@ -64,33 +66,43 @@ def append_message(self, role, message): def get_images(self, return_pil=False): images = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO + from PIL import Image + msg, image, image_process_mode = msg if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): width, height = pil_img.size if width == height: return pil_img elif width > height: - result = Image.new(pil_img.mode, (width, width), background_color) + result = Image.new( + pil_img.mode, (width, width), background_color + ) result.paste(pil_img, (0, (width - height) // 2)) return result else: - result = Image.new(pil_img.mode, (height, height), background_color) + result = Image.new( + pil_img.mode, (height, height), background_color + ) result.paste(pil_img, ((height - width) // 2, 0)) return result + image = expand2square(image) elif image_process_mode == "Crop": pass elif image_process_mode == "Resize": image = image.resize((224, 224)) else: - raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + raise ValueError( + f"Invalid image_process_mode: {image_process_mode}" + ) max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 @@ -113,11 +125,12 @@ def expand2square(pil_img, background_color=(122, 116, 104)): def to_gradio_chatbot(self): ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO + msg, image, image_process_mode = msg max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw @@ -135,7 +148,7 @@ def to_gradio_chatbot(self): image.save(buffered, format="JPEG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() img_str = f'user upload image' - msg = msg.replace('', img_str) + msg = msg.replace("", img_str) ret.append([msg, None]) else: ret[-1][-1] = msg @@ 
-149,14 +162,17 @@ def copy(self): offset=self.offset, sep_style=self.sep_style, sep=self.sep, - sep2=self.sep2) + sep2=self.sep2, + ) def dict(self): if len(self.get_images()) > 0: return { "system": self.system, "roles": self.roles, - "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "messages": [ + [x, y[0] if type(y) is tuple else y] for x, y in self.messages + ], "offset": self.offset, "sep": self.sep, "sep2": self.sep2, @@ -173,11 +189,12 @@ def dict(self): conv_v1 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ("Human", "Give three tips for staying healthy."), - ("Assistant", + ( + "Assistant", "Sure, here are three tips for staying healthy:\n" "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " @@ -191,7 +208,8 @@ def dict(self): "3. Get enough sleep: Getting enough quality sleep is essential for your physical " "and mental health. Adults should aim for seven to nine hours of sleep per night. " "Establish a regular sleep schedule and try to create a relaxing bedtime routine to " - "help improve the quality of your sleep.") + "help improve the quality of your sleep.", + ), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -200,11 +218,15 @@ def dict(self): conv_v1_2 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( - ("Human", "What are the key differences between renewable and non-renewable energy sources?"), - ("Assistant", + ( + "Human", + "What are the key differences between renewable and non-renewable energy sources?", + ), + ( + "Assistant", "Renewable energy sources are those that can be replenished naturally in a relatively " "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " "Non-renewable energy sources, on the other hand, are finite and will eventually be " @@ -222,7 +244,8 @@ def dict(self): "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " "situations and needs, while non-renewable sources are more rigid and inflexible.\n" "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " - "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n", + ), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -280,12 +303,9 @@ def dict(self): simple_conv = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), - messages=( - ("Human", "Hi!"), - ("Assistant", "Hi there! How can I help you today?") - ), + messages=(("Human", "Hi!"), ("Assistant", "Hi there! 
How can I help you today?")), offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", @@ -293,12 +313,12 @@ def dict(self): simple_conv_multimodal = Conversation( system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." - "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." - "Follow the instructions carefully and explain your answers in detail.", + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", roles=("Human", "Assistant"), messages=( ("Human", "Hi!"), - ("Assistant", "Hi there! How can I help you today?\n") + ("Assistant", "Hi there! How can I help you today?\n"), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -320,12 +340,12 @@ def dict(self): simple_conv_legacy = Conversation( system="You are LLaVA, a large language model trained by UW Madison WAIV Lab." - "You are designed to assist human with a variety of tasks using natural language." - "Follow the instructions carefully.", + "You are designed to assist human with a variety of tasks using natural language." + "Follow the instructions carefully.", roles=("Human", "Assistant"), messages=( ("Human", "Hi!\n\n### Response:"), - ("Assistant", "Hi there! How can I help you today?\n") + ("Assistant", "Hi there! How can I help you today?\n"), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -334,8 +354,8 @@ def dict(self): conv_llava_v1 = Conversation( system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." - "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." - "Follow the instructions carefully and explain your answers in detail.", + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
+ "Follow the instructions carefully and explain your answers in detail.", roles=("USER", "ASSISTANT"), version="v1", messages=(), @@ -353,7 +373,6 @@ def dict(self): "multimodal": simple_conv_multimodal, "mpt_multimodal": simple_conv_mpt_multimodal, "llava_v1": conv_llava_v1, - # fastchat "v1": conv_v1_2, "bair_v1": conv_bair_v1, diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index eb1e67d..6ff2a14 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -1,25 +1,28 @@ -import os -import torch -from typing import Tuple from dataclasses import dataclass -from tqdm.auto import tqdm - -from PIL import Image +from pathlib import Path -import transformers import diffusers +import torch +import transformers + +from .base import DEFAULT_DEVICE +from .mgie_llava import LlavaLlamaForCausalLM # from llava.conversation import conv_templates # from llava.model import * -from .mgie_llava import LlavaLlamaForCausalLM -from .base import DEFAULT_DEVICE DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" -PATH_LLAVA = "./_ckpt/LLaVA-7B-v1" +PATH_LLAVA = Path("./_ckpt/LLaVA-7B-v1") +PATH_MLLM = Path("./_ckpt/mgie_7b/mllm.pt") +PATH_UNET = Path("./_ckpt/mgie_7b/unet.pt") + +assert PATH_LLAVA.exists() +assert PATH_MLLM.exists() +assert PATH_UNET.exists() @dataclass @@ -28,27 +31,20 @@ class MGIEParams: class MGIE: - def __init__(self, params: MGIEParams) -> None: + def __init__(self, params: MGIEParams = MGIEParams()) -> None: self.params = params self.tokenizer: transformers.AutoTokenizer = None self.model: LlavaLlamaForCausalLM = None self.image_processor: transformers.CLIPImageProcessor = None self.image_token_len: int = None self.emb: torch.Tensor = None - self._get_model() + self._set_model() self.pipe = self._get_pipe() - def _get_model( - self, - ) -> Tuple[ - LlavaLlamaForCausalLM, - transformers.AutoTokenizer, - transformers.CLIPImageProcessor, - int, - ]: + def _set_model(self): tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA) model = LlavaLlamaForCausalLM.from_pretrained( - PATH_LLAVA, + PATH_LLAVA.absolute(), low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True, @@ -72,7 +68,7 @@ def _get_model( special_tokens=True, ) model.resize_token_embeddings(len(tokenizer)) - ckpt = torch.load("./_ckpt/mgie_7b/mllm.pt", map_location="cpu") # TO DEVICE? + ckpt = torch.load(PATH_MLLM, map_location="cpu") # TO DEVICE? 
         model.load_state_dict(ckpt, strict=False)

         mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
@@ -95,35 +91,37 @@ def _get_model(
         )[0]
         vision_config.use_im_start_end = mm_use_im_start_end
         if mm_use_im_start_end:
-            vision_config.im_start_token, vision_config.im_end_token = (
-                tokenizer.convert_tokens_to_ids(
-                    [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
-                )
+            (
+                vision_config.im_start_token,
+                vision_config.im_end_token,
+            ) = tokenizer.convert_tokens_to_ids(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
             )

-        image_token_len: int = (
-            vision_config.image_size // vision_config.patch_size
-        ) ** 2
+        image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

-        _ = model.eval()
-        EMB = ckpt["emb"].to(self.params.device)
+        # model = model.to(self.params.device)
+        model.eval()
+        emb = ckpt["emb"].to(self.params.device)
         with torch.inference_mode():
-            NULL = model.edit_head(
-                torch.zeros(1, 8, 4096).half().to(self.params.device), EMB
+            null = model.edit_head(
+                torch.zeros(1, 8, 4096, device=self.params.device, dtype=torch.float16),
+                emb,
             )
-        print("NULL:", NULL.shape)
+        print("NULL:", null.shape)

         self.model = model
         self.tokenizer = tokenizer
         self.image_processor = image_processor
         self.image_token_len = image_token_len
-        self.emb = EMB
-        self.null = NULL
+        self.emb = emb
+        self.null = null

     def _get_pipe(self) -> diffusers.StableDiffusionInstructPix2PixPipeline:
         pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained(
-            "timbrooks/instruct-pix2pix", torch_dtype=T.float16, safety_checker=None
-        ).to("cuda")
+            "timbrooks/instruct-pix2pix",
+            torch_dtype=torch.float16,  # , safety_checker=None
+        ).to(self.params.device)
         pipe.set_progress_bar_config(disable=True)
         pipe.unet.load_state_dict(
-            torch.load("./_ckpt/mgie_7b/unet.pt", map_location="cpu")
+            torch.load(PATH_UNET, map_location="cpu")
         )  # TO DEVICE?
         return pipe
diff --git a/ml_mgie/ml_mgie/mgie_llava.py b/ml_mgie/ml_mgie/mgie_llava.py
index 3b615dd..8733c85 100644
--- a/ml_mgie/ml_mgie/mgie_llava.py
+++ b/ml_mgie/ml_mgie/mgie_llava.py
@@ -4,35 +4,32 @@
 #
 # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py

+import os
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+from ml_mgie.base import DEFAULT_DEVICE
 from torch.nn import CrossEntropyLoss
-
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    CLIPImageProcessor,
+    CLIPVisionModel,
     LlamaConfig,
-    LlamaModel,
     LlamaForCausalLM,
-    CLIPVisionModel,
-    CLIPImageProcessor,
+    LlamaModel,
 )
-
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )

-import os, diffusers
-
 DEFAULT_IMAGE_TOKEN = ""
 DEFAULT_IMAGE_PATCH_TOKEN = ""
 DEFAULT_IM_START_TOKEN = ""
 DEFAULT_IM_END_TOKEN = ""
-REGISTER_NAME = "mgie-llava"
+REGISTER_NAME = "llava"


 class LlavaConfig(LlamaConfig):
@@ -44,6 +41,7 @@ class LlavaLlamaModel(LlamaModel):

     def __init__(self, config: LlamaConfig):
         super(LlavaLlamaModel, self).__init__(config)
+        self.to_device = DEFAULT_DEVICE

         if hasattr(config, "mm_vision_tower"):
             # HACK: for FSDP
@@ -121,7 +119,6 @@ def forward(
         images: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-
         # HACK: replace back original embeddings for LLaVA pretraining
         orig_embeds_params = getattr(self, "orig_embeds_params", None)
         # if orig_embeds_params is not None:
@@ -345,7 +342,7 @@ def __init__(self, config):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

         self.edit_head = EditMapper()
-
+        """
         self.scheduler, self.vae, self.unet = [
             diffusers.DDPMScheduler.from_pretrained(
                 "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
@@ -370,7 +367,7 @@ def __init__(self, config):
             conv.weight.zero_()
             conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
             self.unet.conv_in = conv
-
+        """
         # Initialize weights and apply final processing
         self.post_init()

@@ -570,10 +567,11 @@ def initialize_vision_tokenizer(
                 [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
             )
             self.resize_token_embeddings(len(tokenizer))
-            vision_config.im_start_token, vision_config.im_end_token = (
-                tokenizer.convert_tokens_to_ids(
-                    [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
-                )
+            (
+                vision_config.im_start_token,
+                vision_config.im_end_token,
+            ) = tokenizer.convert_tokens_to_ids(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
             )

             if num_new_tokens > 0:
@@ -621,5 +619,5 @@ def initialize_vision_tokenizer(
     )[0]


-AutoConfig.register(REGISTER_NAME, LlavaConfig)
-AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
+AutoConfig.register(REGISTER_NAME, LlavaConfig, exist_ok=True)
+AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM, exist_ok=True)

From 67a97148221607e81712868580e1ed0b0d4e586a Mon Sep 17 00:00:00 2001
From: Paul Asquin
Date: Thu, 8 Feb 2024 23:23:26 +0100
Subject: [PATCH 04/23] add preco yaml

---
 .pre-commit-config.yaml | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 .pre-commit-config.yaml

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..ed45766
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,31 @@
+repos:
+- repo: https://github.com/myint/autoflake
+  rev: v1.4
+  hooks:
+  - id: autoflake
+    name: autoflake
+    entry: autoflake
--in-place --remove-all-unused-imports --ignore-init-module-imports + language: python + +- repo: https://github.com/ambv/black + rev: 23.1.0 + hooks: + - id: black + files: \.py$ + +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black"] + files: \.py$ + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + args: ['--unsafe'] + - id: end-of-file-fixer + files: \.py$ + - id: trailing-whitespace + files: \.py$ \ No newline at end of file From dc8017d9ec717313704a81e8168070ed7f2db44a Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 10:27:04 +0000 Subject: [PATCH 05/23] remove llava --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 37345ae..15ee2e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "ml-mgie" packages = [ { include = "ml_mgie", from = "./ml_mgie" }, - { include = "llava", from = "./LLaVa" }, + # { include = "llava", from = "./LLaVa" }, ] version = "v0.0.0" description = "" From b913dfdef79b204f03a6541439c09989230e864b Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 13:56:37 +0100 Subject: [PATCH 06/23] refacto to add processing in object --- ml_mgie/demo/inference.py | 37 ++++---------- ml_mgie/ml_mgie/mgie.py | 93 ++++++++++++++++++++++++++--------- ml_mgie/ml_mgie/mgie_llava.py | 10 ++-- pyproject.toml | 6 +++ 4 files changed, 88 insertions(+), 58 deletions(-) diff --git a/ml_mgie/demo/inference.py b/ml_mgie/demo/inference.py index fb5e7b5..5232eeb 100644 --- a/ml_mgie/demo/inference.py +++ b/ml_mgie/demo/inference.py @@ -3,15 +3,8 @@ from pathlib import Path import torch -from ml_mgie.llava_conversation import conv_templates -from ml_mgie.mgie import ( - DEFAULT_IM_END_TOKEN, - DEFAULT_IM_START_TOKEN, - DEFAULT_IMAGE_PATCH_TOKEN, - MGIE, - MGIEParams, -) -from ml_mgie.utils import crop_resize, remove_alter +from ml_mgie.mgie import MGIE, MGIEParams +from ml_mgie.utils import remove_alter from PIL import Image from tqdm import tqdm @@ -48,26 +41,12 @@ ] for i in tqdm(range(len(ins))): image_input_path = input_path / f"{i}.jpg" - img = crop_resize(Image.open(image_input_path).convert("RGB")) + image = Image.open(image_input_path).convert("RGB") instruction = ins[i] - img = mgie.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][0] - prompt = f"what will this image be like if {instruction}" - prompt = ( - prompt - + "\n" - + DEFAULT_IM_START_TOKEN - + DEFAULT_IMAGE_PATCH_TOKEN * mgie.image_token_len - + DEFAULT_IM_END_TOKEN - ) - conv = conv_templates["vicuna_v1_1"].copy() - conv.append_message(conv.roles[0], prompt) - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - prompt_tokenized = mgie.tokenizer(prompt) - prompt_tensor_ids = torch.as_tensor(prompt_tokenized["input_ids"]) - mask = torch.as_tensor(prompt_tokenized["attention_mask"]) - + # Prepare inputs + img = mgie.prepare_img(image) + prompt_tensor_ids, mask = mgie.prepare_prompt_id_and_mask(instruction) with torch.inference_mode(): out = mgie.model.generate( prompt_tensor_ids.unsqueeze(dim=0).to(params.device), @@ -80,6 +59,10 @@ return_dict_in_generate=True, output_hidden_states=True, ) + import pdb + + pdb.set_trace() + # Here out is nonesense: "Pres flash togful calledgot At commitilli split sent" out, hid = ( out["sequences"][0].tolist(), torch.cat([x[-1] for x in out["hidden_states"]], dim=1)[0], diff --git a/ml_mgie/ml_mgie/mgie.py 
b/ml_mgie/ml_mgie/mgie.py index 6ff2a14..efd8e4b 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -1,12 +1,16 @@ from dataclasses import dataclass from pathlib import Path +from typing import Tuple import diffusers import torch import transformers +from PIL import Image from .base import DEFAULT_DEVICE +from .llava_conversation import conv_templates from .mgie_llava import LlavaLlamaForCausalLM +from .utils import crop_resize # from llava.conversation import conv_templates # from llava.model import * @@ -38,17 +42,22 @@ def __init__(self, params: MGIEParams = MGIEParams()) -> None: self.image_processor: transformers.CLIPImageProcessor = None self.image_token_len: int = None self.emb: torch.Tensor = None + self.pipe: diffusers.StableDiffusionInstructPix2PixPipeline = None self._set_model() - self.pipe = self._get_pipe() def _set_model(self): tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA) model = LlavaLlamaForCausalLM.from_pretrained( PATH_LLAVA.absolute(), - low_cpu_mem_usage=True, + # low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True, ).to(self.params.device) + model.model = model.model.to(self.params.device) + model.model.vision_tower[0] = model.get_vision_tower().to(self.params.device) + model.lm_head = model.lm_head.to(self.params.device) + model.edit_head = model.edit_head.to(self.params.device) + image_processor = transformers.CLIPImageProcessor.from_pretrained( model.config.mm_vision_tower, torch_dtype=torch.float16 ) @@ -68,8 +77,10 @@ def _set_model(self): special_tokens=True, ) model.resize_token_embeddings(len(tokenizer)) - ckpt = torch.load(PATH_MLLM, map_location="cpu") # TO DEVICE? - model.load_state_dict(ckpt, strict=False) + # ckpt = torch.load(PATH_MLLM, map_location=self.params.device) # TO DEVICE? 
+ ckpt = torch.load(PATH_MLLM, map_location="cpu") + # incompatible_keys = model.load_state_dict(ckpt, strict=False, assign=True) + incompatible_keys = model.load_state_dict(ckpt, strict=False) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) @@ -78,14 +89,16 @@ def _set_model(self): [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True ) - vision_tower = model.get_model().vision_tower[0] - vision_tower = transformers.CLIPVisionModel.from_pretrained( - vision_tower.config._name_or_path, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - ).to(self.params.device) - model.get_model().vision_tower[0] = vision_tower - vision_config = vision_tower.config + vision_tower = model.get_vision_tower() + vision_tower: transformers.CLIPVisionModel = ( + transformers.CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(self.params.device) + ) + model.model.vision_tower[0] = vision_tower + vision_config: transformers.PretrainedConfig = vision_tower.config vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( [DEFAULT_IMAGE_PATCH_TOKEN] )[0] @@ -100,28 +113,60 @@ def _set_model(self): image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 # model = model.to(self.params.device) - model.eval() + _ = model.eval() emb = ckpt["emb"].to(self.params.device) with torch.inference_mode(): null = model.edit_head( - torch.zeros(1, 8, 4096, device=self.params.device, dtype=torch.float16), + torch.zeros(1, 8, 4096).half().to(self.params.device), emb, ) - print("NULL:", null.shape) - self.model = model - self.tokenizer = tokenizer - self.image_processor = image_processor - self.image_token_len = image_token_len - self.emb = emb - self.null = null - def _get_pipe(self) -> diffusers.StableDiffusionInstructPix2PixPipeline: pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained( "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, # , safety_checker=None ).to(self.params.device) pipe.set_progress_bar_config(disable=True) + """pipe.unet.load_state_dict( + torch.load(PATH_UNET.absolute(), map_location=self.params.device), + assign=True, + strict=True, + ) # TO DEVICE?""" pipe.unet.load_state_dict( - torch.load(PATH_UNET, map_location="cpu") + torch.load(PATH_UNET.absolute(), map_location="cpu"), ) # TO DEVICE? 
- return pipe + + self.model = model + self.tokenizer = tokenizer + self.image_processor = image_processor + self.image_token_len = image_token_len + self.emb = emb + self.null = null + self.pipe = pipe + + def prepare_img(self, image: Image.Image) -> torch.Tensor: + """image: PIL.Image.Image, Pillow RGB image""" + img = crop_resize(image) + img = self.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][ + 0 + ] + return img + + def prepare_prompt_id_and_mask( + self, instruction: str + ) -> Tuple[torch.Tensor, torch.Tensor]: + prompt = f"what will this image be like if {instruction}" + prompt = ( + prompt + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * self.image_token_len + + DEFAULT_IM_END_TOKEN + ) + conv = conv_templates["vicuna_v1_1"].copy() + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + prompt_tokenized = self.tokenizer(prompt) + prompt_tensor_ids = torch.as_tensor(prompt_tokenized["input_ids"]) + mask = torch.as_tensor(prompt_tokenized["attention_mask"]) + return prompt_tensor_ids, mask diff --git a/ml_mgie/ml_mgie/mgie_llava.py b/ml_mgie/ml_mgie/mgie_llava.py index 8733c85..aee22fe 100644 --- a/ml_mgie/ml_mgie/mgie_llava.py +++ b/ml_mgie/ml_mgie/mgie_llava.py @@ -371,15 +371,11 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def get_model(self): - return self.model - def get_vision_tower(self): - return self.get_model().get_vision_tower() + return self.model.get_vision_tower() def get_vision_tower(self): - model = self.get_model() - vision_tower = model.vision_tower + vision_tower = self.model.vision_tower if type(vision_tower) is list: vision_tower = vision_tower[0] return vision_tower @@ -589,7 +585,7 @@ def initialize_vision_tokenizer( output_embeddings[-num_new_tokens:] = output_embeddings_avg if tune_mm_mlp_adapter: - self.get_model().orig_embeds_params = [ + self.model.orig_embeds_params = [ self.get_input_embeddings().weight.data.clone().to(device=device) ] for p in self.get_input_embeddings().parameters(): diff --git a/pyproject.toml b/pyproject.toml index 15ee2e5..89b7231 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,3 +31,9 @@ diffusers = "^0.26.2" sentencepiece = "^0.1.99" protobuf = "^4.25.2" accelerate = "^0.26.1" +clip = {git = "https://github.com/openai/CLIP.git"} +evaluate = "^0.4.1" +tokenizers = "^0.15.1" +torchvision = "^0.17.0" +deepspeed = "^0.13.1" +ninja = "^1.11.1.1" From 3b942dc9f5b43caf581620415da224cf7506a17f Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 14:06:12 +0100 Subject: [PATCH 07/23] add pure llava --- ml_mgie/ml_mgie/llava.py | 605 +++++++++++++++++++++++++++++++++++++++ ml_mgie/ml_mgie/mgie.py | 4 +- 2 files changed, 608 insertions(+), 1 deletion(-) create mode 100644 ml_mgie/ml_mgie/llava.py diff --git a/ml_mgie/ml_mgie/llava.py b/ml_mgie/ml_mgie/llava.py new file mode 100644 index 0000000..19262e6 --- /dev/null +++ b/ml_mgie/ml_mgie/llava.py @@ -0,0 +1,605 @@ +# PURE LLAVA FROM ML_MGIE EXCEPTS REGISTER EXIST OK +# modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py + +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + CLIPImageProcessor, + CLIPVisionModel, + LlamaConfig, + LlamaForCausalLM, + LlamaModel, +) +from transformers.modeling_outputs 
import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + + +class LlavaConfig(LlamaConfig): + model_type = "llava" + + +class LlavaLlamaModel(LlamaModel): + config_class = LlavaConfig + + def __init__(self, config: LlamaConfig): + super(LlavaLlamaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + # HACK: for FSDP + self.vision_tower = [ + CLIPVisionModel.from_pretrained(config.mm_vision_tower) + ] + # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) + + if hasattr(config, "use_mm_proj"): + self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) + + def get_vision_tower(self): + vision_tower = getattr(self, "vision_tower", None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_vision_modules( + self, + vision_tower, + mm_vision_select_layer, + pretrain_mm_mlp_adapter=None, + fsdp=None, + ): + self.config.mm_vision_tower = vision_tower + + image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + if not hasattr(self, "vision_tower"): + vision_tower = CLIPVisionModel.from_pretrained(vision_tower) + else: + vision_tower = self.vision_tower[0] + vision_tower.requires_grad_(False) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + else: + self.vision_tower = vision_tower + + vision_config = vision_tower.config + num_patches = (vision_config.image_size // vision_config.patch_size) ** 2 + + self.config.use_mm_proj = True + self.config.mm_hidden_size = vision_config.hidden_size + self.config.mm_vision_select_layer = mm_vision_select_layer + + if not hasattr(self, "mm_projector"): + self.mm_projector = nn.Linear( + vision_config.hidden_size, self.config.hidden_size + ) + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + self.mm_projector.load_state_dict( + {k.split(".")[-1]: v for k, v in mm_projector_weights.items()} + ) + + return dict( + image_processor=image_processor, + image_token_len=num_patches, + vision_config=vision_config, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + # HACK: replace back original embeddings for LLaVA pretraining + orig_embeds_params = getattr(self, "orig_embeds_params", None) + # if orig_embeds_params is not None: + # orig_embeds_params = orig_embeds_params[0] + # with torch.no_grad(): + # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + vision_tower = self.get_vision_tower() + if ( + vision_tower is not None + and (input_ids.shape[1] != 1 or self.training) + and images is not None + ): + # TODO: this is a modified multimodal LLM -- Haotian Liu + with torch.no_grad(): + if type(images) is list: + # variable length images + image_features = [] + for image in images: + image_forward_out = vision_tower( + image.unsqueeze(0), output_hidden_states=True + ) + 
select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_out.hidden_states[ + select_hidden_state_layer + ] + image_feature = select_hidden_state[:, 1:] + image_features.append(image_feature) + else: + image_forward_outs = vision_tower( + images.to(vision_tower.dtype), output_hidden_states=True + ) + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_outs.hidden_states[ + select_hidden_state_layer + ] + image_features = select_hidden_state[:, 1:].to(images.dtype) + if type(images) is list: + image_features = [ + self.mm_projector(image_feature)[0] + for image_feature in image_features + ] + else: + image_features = self.mm_projector(image_features) + dummy_image_features = torch.zeros( + 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + dummy_image_features = self.mm_projector(dummy_image_features) + + new_input_embeds = [] + cur_image_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = ( + cur_input_embeds + (0.0 * dummy_image_features).sum() + ) + new_input_embeds.append(cur_input_embeds) + cur_image_idx += 1 + continue + if vision_tower.config.use_im_start_end: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == vision_tower.config.im_start_token).sum() != ( + cur_input_ids == vision_tower.config.im_end_token + ).sum(): + raise ValueError( + "The number of image start tokens and image end tokens should be the same." + ) + image_start_tokens = torch.where( + cur_input_ids == vision_tower.config.im_start_token + )[0] + for image_start_token_pos in image_start_tokens: + cur_image_features = image_features[cur_image_idx].to( + device=cur_input_embeds.device + ) + num_patches = cur_image_features.shape[0] + if ( + cur_input_ids[image_start_token_pos + num_patches + 1] + != vision_tower.config.im_end_token + ): + raise ValueError( + "The image end token should follow the image start token." + ) + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:image_start_token_pos].detach(), + cur_input_embeds[ + image_start_token_pos : image_start_token_pos + + 1 + ], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + + num_patches + + 1 : image_start_token_pos + + num_patches + + 2 + ], + cur_input_embeds[ + image_start_token_pos + num_patches + 2 : + ].detach(), + ), + dim=0, + ) + else: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[: image_start_token_pos + 1], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + num_patches + 1 : + ], + ), + dim=0, + ) + cur_image_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if ( + cur_input_ids == vision_tower.config.im_patch_token + ).sum() != num_patches: + raise ValueError( + "The number of image patch tokens should be the same as the number of image patches." 
+ ) + masked_indices = torch.where( + cur_input_ids == vision_tower.config.im_patch_token + )[0] + mask_index_start = masked_indices[0] + if ( + masked_indices + != torch.arange( + mask_index_start, + mask_index_start + num_patches, + device=masked_indices.device, + dtype=masked_indices.dtype, + ) + ).any(): + raise ValueError( + "The image patch tokens should be consecutive." + ) + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start].detach(), + cur_image_features, + cur_input_embeds[ + mask_index_start + num_patches : + ].detach(), + ), + dim=0, + ) + else: + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start], + cur_image_features, + cur_input_embeds[mask_index_start + num_patches :], + ), + dim=0, + ) + new_input_embeds.append(cur_new_input_embeds) + cur_image_idx += 1 + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return super(LlavaLlamaModel, self).forward( + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class EditMapper(nn.Module): + def __init__(self): + super().__init__() + + self.llm2hid = nn.Linear(4096, 512) + self.query = nn.Parameter(torch.randn(1, 77, 512)) + self.mapper = nn.Transformer( + batch_first=True, + norm_first=True, + d_model=512, + nhead=4, + num_encoder_layers=4, + num_decoder_layers=4, + dim_feedforward=2048, + dropout=0.0, + ) + self.hid2feat = nn.Linear(512, 768) + + def forward(self, llm, emb): + hid = self.llm2hid(llm + emb) + hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1)) + feat = self.hid2feat(hid) + + return feat + + +class LlavaLlamaForCausalLM(LlamaForCausalLM): + config_class = LlavaConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = LlavaLlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.edit_head = EditMapper() + + """self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'), + diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'), + diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')] + self.vae.requires_grad_(False) + self.unet.register_to_config(in_channels=8) + with torch.no_grad(): + conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) + conv.weight.zero_() + conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) + self.unet.conv_in = conv""" + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_vision_tower(self): + model = self.get_model() + vision_tower = model.vision_tower + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + p2p_inp=None, + p2p_ans=None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + images=images, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model/pipeline parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if labels is not None: + llm = [] + for i in range(labels.shape[0]): + try: + p = labels[i].data.cpu().tolist().index(32003) - 1 + except: + p = len(labels[i]) - 9 + p = min(len(hidden_states[i]) - 9, p) + llm.append(hidden_states[i][p : p + 8].unsqueeze(0)) + llm = torch.cat(llm, dim=0) + hid_edit = self.edit_head( + llm, + self.model.embed_tokens.weight[-8:] + .unsqueeze(dim=0) + .repeat(labels.shape[0], 1, 1), + ) + + B, DROP = labels.shape[0], 0.05 + + hid_null = self.edit_head( + torch.zeros(B, 8, 4096, device=labels.device), + self.model.embed_tokens.weight[-8:] + .unsqueeze(dim=0) + .repeat(labels.shape[0], 1, 1), + ) + + with torch.no_grad(): + lat_ans, lat_inp = ( + self.vae.encode(p2p_ans).latent_dist.sample() + * self.vae.config.scaling_factor, + self.vae.encode(p2p_inp).latent_dist.mode(), + ) + lat_ans, lat_inp = [ + torch.from_numpy(lat_ans.data.cpu().float().numpy()).to( + lat_ans.device + ), + torch.from_numpy(lat_inp.data.cpu().float().numpy()).to( + lat_inp.device + ), + ] + + noise = torch.randn_like(lat_ans) + ts = torch.randint( + 0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device + ).long() + lat_noise = self.scheduler.add_noise(lat_ans, noise, ts) + + prob = torch.rand(B, device=lat_ans.device) + mask = (prob < (DROP * 2)).reshape(B, 1, 1) + hid_edit = torch.where(mask, hid_null, hid_edit) + mask = ( + 1.0 + - ( + (prob >= DROP).to(lat_inp.dtype) + * (prob < (DROP * 3)).to(lat_inp.dtype) + ) + ).reshape(B, 1, 1, 1) + lat_inp *= mask + + out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample + + loss_ce, loss_edit = loss, nn.functional.mse_loss( + out, noise, reduction="mean" + ) + if int(os.environ["LOCAL_RANK"]) == 0: + print("loss_ce:", loss_ce, "/", "loss_edit:", loss_edit) + loss = loss_ce + loss_edit * 0.5 + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, 
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs,
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "images": kwargs.get("images", None),
+            }
+        )
+        return model_inputs
+
+    def initialize_vision_tokenizer(
+        self,
+        mm_use_im_start_end,
+        tokenizer,
+        device,
+        tune_mm_mlp_adapter=False,
+        pretrain_mm_mlp_adapter=None,
+    ):
+        vision_config = self.get_vision_tower().config
+        vision_config.use_im_start_end = mm_use_im_start_end
+        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+        self.resize_token_embeddings(len(tokenizer))
+
+        if mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
+            )
+            self.resize_token_embeddings(len(tokenizer))
+            (
+                vision_config.im_start_token,
+                vision_config.im_end_token,
+            ) = tokenizer.convert_tokens_to_ids(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
+            )
+
+            if num_new_tokens > 0:
+                input_embeddings = self.get_input_embeddings().weight.data
+                output_embeddings = self.get_output_embeddings().weight.data
+
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+                    dim=0, keepdim=True
+                )
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+                    dim=0, keepdim=True
+                )
+
+                input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+            if tune_mm_mlp_adapter:
+                self.get_model().orig_embeds_params = [
+                    self.get_input_embeddings().weight.data.clone().to(device=device)
+                ]
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = True
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
+
+            if pretrain_mm_mlp_adapter:
+                mm_projector_weights = torch.load(
+                    pretrain_mm_mlp_adapter, map_location="cpu"
+                )
+                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
+                assert num_new_tokens == 2
+                if input_embeddings.shape == embed_tokens_weight.shape:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[
+                        -num_new_tokens:
+                    ]
+                elif embed_tokens_weight.shape[0] == num_new_tokens:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
+                else:
+                    raise ValueError(
+                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
+ ) + + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] + + +# AutoConfig.register("llava", LlavaConfig) +# AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) +AutoConfig.register("llava", LlavaConfig, exist_ok=True) +AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM, exist_ok=True) diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index efd8e4b..c664256 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -8,8 +8,10 @@ from PIL import Image from .base import DEFAULT_DEVICE + +# from .mgie_llava import LlavaLlamaForCausalLM +from .llava import LlavaLlamaForCausalLM from .llava_conversation import conv_templates -from .mgie_llava import LlavaLlamaForCausalLM from .utils import crop_resize # from llava.conversation import conv_templates From df7b0a281787a4536e9b09e3c3477828fae95566 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 14:06:51 +0100 Subject: [PATCH 08/23] add test image --- ml_mgie/tests/data/0.jpg | Bin 0 -> 64790 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 ml_mgie/tests/data/0.jpg diff --git a/ml_mgie/tests/data/0.jpg b/ml_mgie/tests/data/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0e8692776ca5ada541ee4636fc17d580d3fa8a89 GIT binary patch literal 64790 zcmbrm30zG3|37|GL}+Z;O2e%ySBOic)EsV?Y{_1v2BA$xD(!P5%08A$SJ5B|ErT}N zrqyt}nP{1mnrhNi(`cq?rfKHPng2V@iZ-0gw`G0Exysa#PCXXEjje;~aeu73#*3g`+p)7;& zV4S0e+lEwsG)8G^jUJ;tcAU<5@IcliXq1Mg<|r-A(WAArz|#@nKBP5y^ptta){L3D z?SS@Ae$#%9ymo)A(QnU-^|rqk7%x9~;`BJ3=`&`|n*H;9lLfylTw!iuxzfto?)SCp z?ALGDxMSz8-A>MX_PQTBe8j`k%iBL7@MKVMNNCiVv(Yi<&R@8GxrW2AHS%&+x>4nAc=5uTffBnp)aJ^U@d<1paAG)*3x;*_bJ7wrL;moBGqQkz=R* zcJ2Q2;&DdHw+r+Ro_Mb_-FQXo&%&XpeVy6=*~Cu&U(IY}V&CS~08P-;0Gp>d8G@mH ziV<6^gm~N9!!WwSyXryVJ>lO<==)zLz8O#;vKkHQ@I%kl7i<9K(pYoCus>4_?dd!v zRCjN}kMH)`Dj_14<*d3w?zO5OcgI}ol_??d+}YQG{D z5ek&h=%>>15Xm_kt5(;kzcbO0i*%Am$`Ns7Oe$FPQ`f}1n#gbm;y5kD5R6(0?WO-Q zNN>r&XD+IH^ja1%rGXPA$=kFXL)LT^@j66Eu?ylwvuF#p@w>*85a?UR@f5b6vJYSA zG%Le72(?^z+4ndHY{@LLlTux@gcg``lIj%bBC$hUuMt3`($AI24wxS z_-@55N=Uk7Mw!pbG$qu!P>_E92?vW<{-|FG;gO@;#5bFe-B@p&ytF$d7e)TY;#9k$ z+Kdz-T!aW*m{ARX?B={COw4_+i+uhd5)kgv6@decQe0!H>D9EbCjx zAoU8muv<0Xz28YwjRxgA8U69DN>6@#w};G@iXo3d)m~qy>O;C>29;1@r&>Aw^WL!D zIH_+rsRVggNov#rb85LtJ=S`vf2Eo}n1M>cR89Qz0r>j|j{c(s+kY(SkfM$N!H*Lh zUOEk`S~^r~scv>;tp9d@nR+fGBS{%lE!L^!>P7nJH2!U#sUsa*s7Cs49jh1d_odi? 
z5F@hsZGHZ2op=A+q`#L4d^i1_KeeK|F`3Kkb~FBv+NMC2kk?B?I*qvR!YF12MsH;` zkIW@CWT~BePyxRSl_NH+3MHfmdsCu`<9v^q<-2q66Q#%4A9|g!4>d|?4i0wQ>@I|_ zDxq2Qmdph0YVvz`L4{pSzONrc<_7e5O4D*#LNoP4z8nY_v6>h{Q>mC-^d$I(T{VK( zD1y5%M`-}MQ6N@T&;jvi1mJznE8 z*H~UU7H-lirwyKwo)pO~b6^FTddsIRf>;)so{U|Kc-rLcWN9VepO?>+hACFIWy=2Q zRUj<{&u&?69|{jnm93X%0c5-`I2n)lxwlNjZ5Nts^9Tz0I&J{?{W7NxxXeuY3xjVi z)ZsbDY!vRwIg=^+P`zPJXuD$NTux!I!3C6=e{1X^CA5Zmg_|Sw5SdfC4!|nVg{_!g zN5nx_GXN=l_G@DgE09vNmlFI8lqjTJ=FDhbD0Dh+o}z@Zr#LjHZyB@v`f8IacbBfZ zYDKe%cP;pnYr*$kf3Z(6juF zvF1$8PogW)bXS5FQmhN0`Mma)TJl1oC!!s0`(%$I6Y+7qY**IqAO5Y>BlD zGc2r3Mv7r2bhL7_61tDQL#Da#T3-V=^`YEH{fN!tI~Eq<_4QrZpOjEV(yhzA{R7Op zWI5PAQ3F>*6eP-5$(g;BLG8L3v8_tzExs5~0)eAhqx2(3-eqjN}Yj%RYq*(eb9659|((xV`_Pm}q%dIgp$ zuMqT!`iLiRsb*%2>3?I+1U;+)&a^CXd2Aw}mHCo-+64F)uCCmuD;xO5yo4><)e?6} zpb|RWRBBnj0=|Ojn`riUlv#wg7@Y8{e%9V@5b%aJN~%yoTJsc-$OiV0z1iae9cFN* zM|vrtCE1$;nA1G2Z+NfoXL3s0D6-Lt_L`y}ZX=|ZUiK#8g1IF!)>-q5{vv0d%&zOj zm4curT6d1``ZM62?k{HP2hze`<^>h<`m1|$sA?_4q!ndxh z=xOptB{UT#CId7v`cRPgwvUcSL@YxE42!CzhlNpXW+f9(;wDSA%dKN$%=KN8@+d6rvGMTccH_?+2y};mk}2KmH4arSUsy>y-e!4s%++C z&0!Wb!aUUX*M4I(=kF^a6BM&Iya%|=Fn<}Dh4et*Pw-W^f+jsFPW-wdYNEeD)pD(x zpAD_JVLWJroA=7>G^0CV5RbhsAt+0c13&JWuuD6pl!k7EVz4DLS zxTH2t(vHcWvM%xZdQI!Dxw?MW*F)~cPSk+hhGv8n5P(mS-J&OX(S?c#k}m)vsz^-3 zNzcmH25Yu<#QHszQp!S2qBG^h!HdB$&I2$4}?bf+{@Xn}SYt`w59+8lms=_4Kcc#40T@MYX< z4Q*j+VWu1|MK5FN0QdUfB=!}ez?pd(!^Lj^~A1t^3j+IqAPJ&xkhCwDJEGN_jUvZ)Ms~r=8R1 zy_yx^7MXBkt8MtW~4@i#f&v=G8H!MC7#+ z;v%++6)(JZ)m2J9^6S`>(VlQ!#h(OA?7g*^9z&QCJbsVRc);;i&F;6H4?U7=(Knn; z+aq7=Jb7cI*Ila^Oa7ZESVJ4{brG5X8>*oC)_Cu-3fkJ~3vc zNf)sY(-hz%v6-IM$XNiSPIdyvn0|V{sjX3Iox1}U?MI)x+RNCYOYS1^0BlnR#K1AgGlYjp`*g?uj>&Y%)?j_Gioe)BU^s6%j@Q zqR<5Lnd=BMnHc&#AoOf(oOxTUR>|+k?dqBRYT$aR#x^xNsBb)0sV?O4UK?S^)W3d8 zS7D*%Z_GV~8L>GpIVfC`PXxxZhau3GqzhgNxUbD1&$!P;8A`}bI5CMT@i?CeynsFf zC8Q%(A;7M4NMJ8ZjM9^|*8pj5#0%D%T|_h!Yf+M5VHo8ij7zg~bL;X>#HK`OO}+}X z2DV;9bsio&Xx9(N1l_WfnfFVAvIjyV6p1}AvvYSpE6gP4t&e!PY=Z8x>$pz_=E$S9 zx02hxI^`5Hp^=RaU|EBc0dED&{dRI^i#8f-qJw>X`kQv@@Wl!@AivCOXmtfM?)AswIH8YL0I9FG< zP*dm!%^HIk_!q>oIo7h*65cj}ZEZ8(y_9i(R!e$_!k=-Vs`Oy!$!cPri=sgZ-Api< zh%hnz-!qA(5llH%wYPd)C8SBV2*yj!N+S}{b%HS|zE8YlX|b}yNl9EZIIwF_F1(=L zGN-gxMgObSbg%uV<&Z5@3)wfB8kWZ&jf{?3UVjK7&BBJGQ{UU25tH{n-&I0i0|RP- z{#RS17VVUO^BnEJ1{8iA^900XfX=lgzvQv|%T=~V9l_Yv{w>0GZ-OeEvF*Rl9mwVP zK|9rAO|`=u2hl6GYo`56(BfO@?)whvRSl@d1p5E#`>W|ewrcryVGO2;ydz5jJsWV~eTUw;#Qk zp0G%GxMF%SFXy+S9h7x}@BFU4#h=-mkoKgT*JMO!rgl4wm?koVtMNX&7dI~xB7CgAHdI<+xEs;p@H}xqw?Qm&5c~_OE z;hegBr>D-HkMlfm$>!SDu|kYogclu(5$c zQrn$e(j~f=DUWidaO}%*p@n+&IoK*zK=~{y9lv{hWYL$CySAU~Gj%*nU`HvTd9Fc3 z7HTbLmODmM#}yg@PgO7ZEMa!=sxrcIptsK+_&Fy$~N8|c&2-1v7NFK;*fqU2j7m1gkfPR*&E5$nq znhI!^akP3R)b@xDiw#)IqWZ;rx7K#1C9tGIB?L1K{9WtAm$T_^R6!Gx7S6qxR}jk? 
zZ0yT% zbulW)r|mI4x^0+Lk3?ohNTjyazXZ+HF{SS;kUHS@;2$BS-!6;<^wi3vqb8W&sOy3s zwdo(iE@M@bmMr?Lgp!fu>P~CmXNh;J8ch5vzPcaZx*C56FW7-d=L5fwev}v;i~TCF zxcR7yZeX)GQITg~oBB?-fH~VDJHnoJP=iM-C0;}qr%O%BLK$bn?UP(RbFs6*M*}Y} z>DA~l*ZX|+)Oz_#qjc+0_L>Td{xIhJEo=`SZ*YJo6%yYu&>7m;IF9!D75bD0R;x*j zZ*;*UV2#F`Z1`|uEf)I*YverHxlqux5>?x_f=9_pIsWG!ZH+ic?e4hj(ZW(glyG_aIVVEA~+y3VjM=Mkn=8}Uxv?76ZJjE2Glw~WS zQH1&3v6hB7p+kM9ptiP_loc?r)zmkfn#m*Y zsCR2vfDAhR*I?%l5Bw|8|HBfbMo=1+kl{aM>g({-<5~URjnH#-q<-s|ANoJ)fE$Fl zeJjOUyMhk@t>lVVN=E%d-tE2yl?@$lO65u@;pITr;HeSM^M7?s)nw|2>;M1A@&EZ+ z{TKBdvBUqHrPUrxe7{z&kNa%eHsS*%^n!#=wUba4X=X>w!WXy5X*$6gr)_)7aC&T2E%ElS7e%{Lz`L6G{2g`6~ttTuy}?3z>U-9Al;} zjzz5_)r+uk6Xk2k)i=7DjfK&-{RK#?LzUb=IInECv3HQ{#jVA9)52PuOEdX&t{AOPCExgqm*w~*NVe3?ef5GS(C%3@f z)Jpsz7eb>3sd8%~*2DrYkW5n`2?w)&rR0tv(XS-dL+N@%l+g6N)Bvw3wl zsv9^Pb>n&}CZK8_Y0{f1-{TsLXzsP=!4jti@_eSXg3e%r{ zTIBE{e^;W9p#f`=Z`NY`&J3(PKQm7W{TUt0Bq3d$rqr~HR{2y0>ezP=*OG(jT$Imn z8`}I#-HZcCAI)=0Nwb=}tyz)j)BFz4JjhJEj@wysv*ph8pZ{2b>&Qytciavlg}Ix$ zd}{clNU!T=!$w!S?dOhyAju-!Omi#}5$x#wIGrKMRK$g^Nph}lL`sSTSyTx->I*=a zyOdTUcOcbM7Jc@q_iV!_nl9e655LdI&w96!efck^XGO)#HIYSJy_U8WH~A78kB(|q zmC^&3R3L6UWIX_ye*^gc+gxdJt0@wRpBhM%zMMVrHpa-yYw4a19Df^1x6i<7Pmh?# zOdf{ZMB-lOgQNOSxnu-h? zQ~Cbzxz%E3j*Cn9>MCDCd7Fa?@cI*$b-7*9pL)2-H)sHfb>*jJ?l@PPlb&)UpXQsT zhzNt#R8kdbN;&9>yah^ejq&)W6qZ-Xk9fB0!aoeUx`kaGF(_3uXb^gO!KebgABDwK z9kWZRE33kI0*qKKw)M^jF?4(v2z;r1R5iW>*Y!VmMyHAn|DQvhO65k-bi_fh8xDV| zhNq_GD)(FMA52qc+f|ydS3MmK%GbDL%HWPraujNM@B;KN`A^_I@K8=N%XSrZ`PcqwFyy+}dnD`1D_?g(a*^z8-BlOMgvGUXUd^+f zLw7>rDxzk+BX4~e`Kh1Q87nadvNbtayNKZ^WbzNBxCASqvgy>W6oIHTKiJ*z?c3RP zSp$Ks*MqlC^Ej&~uWY4Hl#+^PTTokx+QI(63u0q3^Wf)ux(J3DAiQR49SXMJyRY-` zm%VSBFo(ukAh!1Rgp<2S@x+d7*pTkdl>WphGCOrSQCh*-DPijqj2ha6N7K{vw+PH) z#LEVDdS3avdG4;BcOOef-#uG+Gs(HpGmadKo%Qf|2qMx)y0W#CIhgHBwF9KS zrYoVR>Fhdcb05*~tvKl;8E#i+VV%{_Bz`r!-p*=oNDZ%QuB#FjZL8q`$~WezT5y%9 zKepESAMq{fjtV$cuc`M-eH|%|SgS^H6};a2q22o+N?VDG#}I1H7O~q*ggL0#VtZ=D z&m18qPsR7YQK{dT4h1M>cfvGQ&k_WoNYV&)oBV}FC?V#)uzr-OO0xg&*=A(e4;WDc z_qPhxAj2^|G)S@hA1o=wa3_HZ!pS;neNb-!baYQ8_xJq?8vec*f@5|O%kk)Rxnxzg zR6{6XwDq2hh5?G5hSvL|u3fbccbaYz$deiEGtDj}roiv4 z3*Ra(g4~wbO!6D*%oQYjr_=Jg22G2p=L3MpW7crs_0E1PHz`p={u+;1&PA|f`a`iHb+RrDcD=$)Ww+vZnAD}D{^j0 z`6JrHrlxC;K5D)_w)%~U(O{6dA`dujXIpwNy3Kfx`n5#tZruJ}ya@(0{Q*)Dlbkci zjgvg9kN3)LB=V|hzP>r}=Sl;6GkBjHU3?TiRcbj;wZnm^HT;C}Oz*il^+qrF&qkGq z@|T&`hWrl+02$3y`-Yq*x`AHz`rD5dLv4vhY|e<^d}>)YV_2P1e$$6V=pbe&G7-?YGZ69HZ>Ui-} z#!buQPPpWoJMaU5rHt4=)8A66{BL%uC#POJw+kwPSOBTOcG|Yp+TTOW)ADH=KLBv* zw$*xSQ||oWlqqYChb0=`$i!}tfDcDYu!Xm*sLb32- zX`FMmz`i`mhnL}fd{Ow8X}5M@eOGi(1iNjhB{+nyW#_c?@xdx8bZ^?;WRtk>!5(i@V4Lf@FEeD|=xC*c+~ zU&^U6S8(Ti=$#4F)hDc0@*>~TUMrH(Yy93Xu(bu1^fy9ziSmBl?cKoVVIFI+IN)|; z^Mak)lFwu!vw~Oc%C53)h}9vN5<(#H$fiA|8h0ggG$;<6xqrkaCuk_LWFm(mU@UKOwZwm7&t3 z_I19~n{VyytOmBVE*T&VQWO!M7uuD~Q|&3(^Fa{uAF7DU+x+e4%mpKM?ZJpDq*Q5q z7aQ6=Ptlu0tQKKeD#ibiNU4NYElB@8?(|eNB;U}FQa^13CH^7tfRz$1R8ezU6`sS= z1AmVz|1tRCu4@5xQ)!^;QZ+IrwchO=Nm>FIe+Css!+_cS$Y&ZLvd5!T-j`aZz5(aA zc!6pIcVB|L?SZ4Y&{1Wr)r$U~98s-dKr2c|Q2yY^I)lM}gSg#l-{~oRSdmZRzVqis z!WwEI*g86R6$XDusl@dnc&-9jQ?MAf)PgLmQr`t>7I^QwY%yKIe(u@BVlx@+`an#8};;dg+Bl6&?UnO)C=pDx@}lk zMg};7d8!*#Py1Dvq0UF&-_<#?)aBE*j>(5T@-MWFX zTv%~&o-c@fUa-TaxGeMW-nqy7A%*|^wRW@4>VwnvABu`1HXU6~232MVtACDZ9}ydG zCSQbaz)AMj1vZ&%%6TcIl%5KANNgUp2K&oME8vUn*e~O>4E@W`H-%#7SmR{~pAEmk z#&xDFFJAHMpGUUOf78qB{nSp*+M91zy_m5=ILRK(ONcrBtKSu*6s{g%VdLyU-fAuj zWGE?Hq)r@#6W?t{tr>ygA<6qY5K34j?g@FlZY=_f-%CR!u9iEq=s5`;<1hLoSv+E; z_ib4hv5Zm?*HzjpQXJHlo#>RkyV>TT?^xIWgnukRvAu$y9Ym%x`MNl$tIGU8efs1H3%Nl-k$@{uv7DarpoC!crh%W_L0phBxfZYT|7y<&0j7GzP 
zbt@K@Ee~0y2})hXCqV^t)gH=<<)t+q;5Pi%g{t_0!Y0^}DXUdik%D+Z8Bn1c7-B@( zI>NWV)ibLlU47j;CPE#|kkTa=0C+}{|H8lEr!mk)>mNuXU>fK5%g(iv;I%4D3 z%GaYl*6(8E`Y4VxHh7;0UJpI?>(PCa;B5V$!6li7H8lJ1fc|JM5+?<>a}&+@jddDN;fi9+CMQgaZsMeN)rCO(4%O;qDIODR76U5PD-|bcy@Ye9%b9 zmfTWg(LG%@UEU$NE|n)@f8!^RMky|~ez_66f;oll^=e+MT)VUt-QIG^mtgAo#%j=U z41HxFrJ-?ORe+qXi;fkB}27gl2HXVry!&D(I8X{D1$8vP81U%=b@ zWvJSLvx{Z!V85qk$(1E$=9?HbW{r3u%TGvO#>e-FIf= zdu~?SedqiAcDgr;Vp55Mq88?6c_4b>`3Mc&y&#b(QNXf(Ck@iH)nq zFur!tT`+P_^~@=29d%#(fS!IGhunjXty-;bh~#vEOwLnm_?j9*gPDDtK?ZeGvr}_JEDxo_@G@ns*(+NN_UivDEED zPg3%^w4?PYG~#W0EA1=Vuo-ljQ9@f&-h`C)qdjz-qaDu7&*Z(?TZH=gzQ8rmZDzGc z4Q+92^iS+ApBu~pdwh(jKlRW4#;zlMyWh2Z0%?oHi`#L?zRAqTve}Be046+;;#D?? zkjt12U&!CkT9Ia1)Ud1Lc34ngkXO0D_|61LyG-=%g!gX44aw;Vbo6``8I(^H&%!2n z7mkx|dK{dxE2@XxKoxQm66oCIdz@r`44MUYsyUrlOwg<<)R!JBEpFC(o}FFYpSSF= z%%jt`Amy3Lzy>9~aErNpzN<_2puqXTh5iXq-_saCp#KZY25fvSq^Gjh-+0aQwbP(& zDo}p>XUMW~T&lv51(LQmp=nM4K{8{|X~1_^#2!pwFyN`a`5W3U?!!rcBjDZ3;$pk~ zT6??1G*BaPfhyQbYXB%3yH!Bz2@+Uyi>+;0GCk^L5EgBFC^V13x-on&_1#M>1CWcB z`Se|`EDq+;FXWKUxj|A=;=W)C16w9wpUDg(O3~VUyHw8xFHKY~BjYc_nIv?^CQxST z9}>|DM-uRLWLxQ{XU{VumRj6xJybic_y}vVE4n(-`0lyND;FHK2b-lSk3cmyZ-G+H$@nWxNv7#hAni zsHLE}q!VpkQA=-oAQX2RYvk;Spw^$w<%2So4%4^}S+xqSK&2M#n#^3tlz{pH*i*5j zP2Of%)-*%1Qwj&Z{orc`x9ZpQX68%YFIPh2;BzOe_+4KRD>H%-nMuW^mk-2*FMk^D zCw!t$xq$t=??&r=vWwq=`_sp)b$jYsv~u6eBBOTtl;E6lL$ex>S{ui7@6a!aCxA_O zm0xRSIM|d+S~W15cR8`!LFFe(jW8|3^T<-fq&Sru5N z$`*_ohB`G$stCPj#8-4ugW1#mt)o?7`+gXxXe-!2C#i2dYHsWTIQtcSS5F@qg(t9< z!{GdW^>j7n?dte8NHz0N1K-G-K7m5!A6>QB9T8QD<|5B@-;?G3kSPR7D=< z4{-h-ZhbUKgdy9g$GrhTB63j?D=WmN*?7{mep?){ML9|Cm2K(G$Rlmh92j5GaEfp$ z#_cq2ki=&9GiNCwtBzOT{DLP2Nn;N9B?U7|<4n#0gE-hjebRGjlz%R+?P$08h%NTD%lWDrwYl+dXoA_I@}d0+yl0~?1&7t$vJnM4@9L&ei~ zPRb%LvGP4-j7U^Fxv`V1Th-q*)e(SPti${Yx*vutc{xZXo!gt9WCk2=CFCUCPqsC^ zS9PE;{|{lct%6Rt$-S#xH~Uz@n4}rp>X{t$Mwjbo(DK=Uk3FrJD|WlrEmEDDb-L$& zR!!y9fgioj;Rp!e1sC8z?Bx43Gw;SGcb`^AB~;I?IUK6`$r&>inw$-vn)JyhVVQ77 zjE!aDUO+@UQrhOYud5DVir73ByAS?K6gk*y>S#iBQn2BGiy!t>t!aN&CDWXTqZqE(QS)XCO-z1`xVBxw#y;sway-=|1$CFtRON8-6Eu*)3}mXnXiO17`+!Q=oxUy_$vyZHqkUQ42Eh0SfTJ)!fQ;zE{>tkJjN$eP8!o8OQRu4&FREmsWqzRGdO0U$0 zui?b^>FTC3O)B)yYU|4;Y4p>Wr1S7>>=$%@-4b*9In1ICb$PsxdROdH@{ze2cXPEt z!Yjhd^v&x`e8WE!-klv`7#D$s_J@uB7|~2^LbIB@BdCwN@y}RAcj^11%8)F_vfy!f zd3m)oTnu5hvvQ1di(Ev*i&ug1oWlW)c9 z93Dxp@EnO}kn!m!olfGH9IK>DagJ^Z>%`c}=Rjpj^xlT%-Gvt)6@=%W2MvNMbvwetINCSl z_N47>U9|1LfIsm5%bkXe&JRp@h{UTfT!pC6b0xGGVGcY7c4f?2)mu&vPT^E0W7x!) 
z@TJIJ+G?$81UQrPsM{UX?e=fg_|GSXe_jBMSUQ!Zsv3QPHhgbk0rWKv>y)CW#qM^j zLKZ@4|78llE9f{L@S{gO1-1E0QBgK1sWz>HU&FM?Mdo41C^HSJ!x3{qlsB87A`8!c#W7n&FJ>lj)()x`=Qm`vno5?u6;U zfn#0?{mTi)Qb$2ar+!7A<8+>OvEj_>{u2Aj`jH39l_07vD$VS-)&C5497x!;(0ixPzkk1lm8%tISY1D8e%w=nau_yJxWaT zXyXVQv1yyQg~+|_p&5$L4vF;EakMHB>t8$hjt(H=&C%j zZC}x}ZEK;g|C}7}1^e$4x;;5?KiJ>C#{91P4&=p0uf4kxy=PzMi3cZ7%kF6z_0-~& zwMLvGPeB1297WCKGHz$~PC>mbKrFw|plL54R}nd7tMOC2t-Y*F*1p%>blra-=Y-?t zJN*jR0vZUS_01*gd&vT^jG!^pq>rE2-7_2|VVg`>tt%|Z^QYv^9%H4c`HM*7wZ#)} z#a+b>q$Qe~mBZJbezT2oF4vb1n8v`#61klC`XTMS`JAQGD?CHTPwc&JXJoaVzI5UE zaC$AVrvL~cUvZW0p64^35{>8vFsDm9nyWrGhi?{UT@}taPdeYfR;O@|VdQFPKQuam z9`c0fX7Fi?{h$P+f5?3{80Kr4o5SYjg~__^SBlRKlJOnnaPr?&!6s(CNV54kAW*H- z1!HalX+LL8w4_cACxOFha~$P6X2JErP;jvB%9)G<9}1OFe*VjRkaNAF&5{LLb;&wX zYD2|SHvH^M--KM3UmB%`;O7(ATog+rv{YMF>MVFV5(TH)~l z3>z!jZeXtT+{o-0aS8~dtI7dJa-6$D)!7ggW8ST@lj?9ofwMX(^WegeQDim>uq?Md zz?f89f$%BiuTv`)mb;f)4gwGDG4rLl`*2uJ=LJa0f*wrIk-KLL(Gl_YYA&hFbBl`KSw6>(+VMnGEzN@-)*nc*^Nu?(bFWA>DDt ze1q$2$qT|g0A_Tdp3#5UjHV{hK&(XNCFT#P47XaS_> z45%7DoA0h0`xqQQ*#Ju4%_7}!!2kO#Bmcu95ED9L>mP?jU>3kGPsc2HvY2z&e1KtP z2}^30&_)k5h<*1{d041N!HKqIZe--CND&BPL1~}qu}a7r**rTIoi-jD=PTFl3NqAJ zTmj`ml~dd7FD=$a|1g`rM+v>L+d0QPRn!(v-X-moJ0Zmf=sxTQf4SYp+Hl0bg{s}; z@6iftA!Rw6nFpjYEDLo0B=~7a?v@h>73R)Feg2H!o^VX6K$tX^SX8K@_g%6)5sC~> zr3vwu%(CKBF7kN1`1p(T!LU4|ZJhcW3>Zt#;!fRogh2XjrI-gSpTTyO%B#rXhk2P8 z3nknF1>i@VfTAqZS@7VLj%g_4q= z^PdI%Hi|Ra?ilwzeklel;SoovrK!s;2ct4Li&cJkd5Lw{kuRbmJnmzb>;B{c-CW08 z7A@g#r&!)+xKUGb!0O|%-<@1pVi4mh6b;@H*q16H{cwBMegW%LzG)k7I&a|jcvq9v zA3AD!A2N&vww|hJyb*b)jed^^j!sMzfTJNBD>e#a`3{`P0_|eft&}j2)=jpyHg@0> z0XtXS+}pGvW{p-~$(jW#XIg3Hh?mO=c%Dm6LDtDG-#|XHxL_MM-&JPj zO8OkO4dhJqJlad1C(XeYI!^7@y1rp0r?g#m z-1;OfrPutI-k%ld0HKh>O3s0aB^lphLix=IOaU;;lV>$PtC%{%u8I5v@bwkG(}Sox zDn?1$J+~JK%0?CM9!h|Jbtk@M4*p`ON*v1Od@WnVUjpWf-tvn5oE?R|*FL?7Y#hGr zWmfl!yLfDJ+H$};&cw*vWlp6SF!WJpZ2T-6v`eWUD@u0PvHLy;wK1poSW{X74^u+O zW`)lnBc8KQ!0jREyeW-Liz zf->E^n}CCe4IQlj5cgz^e=mhcTrRzV=}ZQB zBpu@D+ePO%bo*w5c>^eVJt2R8x$J(V|RUz!n}%f>mW%oQu0}P5q$1} zjo!ew?qfTX5x`@N@w-8ELB6BQ*IjZ(k(s{%9Ht>&BI2{mWO89g2KX{X@=+&hvkx1e z^-H+CcQU%ShR)-a#WAf=>jAJIPJaN`vSp}UJ3U6%4m)8+o26o+6VNd#m-c8hUe!y0 z!gc7J(fN7@T=@n%%UOzXt1*+=OqzTDoMqJpyThwchZyXD8=UE&rdPiQ0|3_}>&_)93OZ#L zTVTZ=Tjb`+=}7iKndb>oKRnvo2ulE;BLE*3vF?sNNgUGP4A={vE;ZrXVMmw~2hd!QQGI@aph93m%4aDeNv<)}1Nw5VaK)Yw zydWOckdo3G!8bWfo$CUae;)KRlZg#xtp%0kEc8`( z14D8wdGG`{v4t2Yp-J?XSR`wp^y4wO)ku+0Xzb$BK$Xff1~|%!?y_cCg067q zIz&gxDRKK5#4jDsukhVONC4Q^zWa>CzZ`h^FmtdU<)Q3;edYzSZmtNO(?**vC9&i~ z>)lLcqCVX;4ZsZ6(Mj;x6KhEK(Fh?qX<|C`A>X;Q?1J`jr!t)lZ%nk7G716|UO{MX zV?`x7QXB}jvd?7T?r@@)MtnmrWlPN%d}1p-qnHC6pKuH5HIffl)mX%xpkJd{x)irn z*BX2tr&YklbakZoC*y+&i_z)AoYVGbp{tZE!lpPZZ)E4J?3Er?98n-NRO_>f#f2{z z4ibgq!Mi&5<)g#5M6>0loWk?1i+|%`2MCrK0E|iKvLYH!3&#~&gz{~Abcws)hi>ihjUZkqwpL$RWvq{Dmr8y%P-kfcWux!CGiVmA^@WdPnT zRx|(viey9iB>9@)*PLQ(68%H}tNy&Tc|t>4>ZCDeA7zZ*g_W6He^!$tg z4nVz}qR;r)G)}%EyT1=hdfpPA_#1FBorN1mW zk{B4V*UovJg!O&{)zi)cxJ00K#;}&c!OaZDw-#Q8&!NcB=ao!KC3N9H3i#%NH7R)z zrFf&-(j0&|zbf(y=S9#rDiV+*#3|vH)r0aG=Z>NH)NZ%=d<=PKT8T{oaybuaNmN2P z9l(7Zr-b;kIF~mS6SUkZ0g6gN$Ek+Rw@?zp(BVs!v}0u*-l_#x?;&{N}h zk=1a*Rup*cr|F=7$G|nM?5ILLAt{oe-<@N8ttEbWD|Nz>xQ`4;NAB1oI!l|PIhXYI zL+EkSe(YK=6W% zy>6^u3$DHwEHI`UOsHdmFichw72OPw6Qff0a?1|pqe^HOSH>IYPqp6Q8|WL_)%UzM zda@aNkm7Z)*PJnR#~sA%GyTYfNk*-CBP0LNJ*mq=JH6*$8Ll+{NQMoFGXtA!PNK+?k~c!=8@ z*Bqk}8bc%%Oj2J6?Y7u#iX*dECIR1tidadr>CIQk&1x>>=fVZGxhI? 
zWxPmcp<^{mRJniCNf%H!`soUstc22^?TT-J0UtUVa6r(tc<*2VT^|Dn(qv(kVee2G zSTYf%RCeT5SjF28Cd&jxn5+hsFB{9@wX4&SHNS4UPL6>nw9j2EVaL<=uA_p_XR;8B zk_)CD4j+2ODbf`M^Z(*)ty$78w@tcvGll5Zw9Mw{WeqFMQ(F)3+&^K|29HJ=7Ps_Q zX9Ha9%$X)e38!xpUl3m;51E<@#lxSU)n)Z+&wy9hN;-bCid>28iAL>DL-->T8+$`OnarG?8#r)rg@YZ3OW|)@LITN@Rp;6S(5HThy&IqzI$aYz%+WQG|KZnJC(W7=y!X=+qcVwt|+Pc8ok3 zU0kPMs&ykws9m-(+_ohGt|sR!3^`0W7#Xv#g#&*G>H$V{j1od8NgczIr7X!~;c4_B zE6sPhG~D*XG_U4y)#9~U8fkwuK4evmf#>?oT6r+RiAtmh)oNJ%K+@*_nQN-4>l9~Y9OKmnUG(L}N63pCe9mJvfIS|=2? z#F3iFc&3b;nB=6ctW;Mas;bezL8eiv^2z_Gs|VE+UlWR|vI%u{=i^ZYU#pTAmQNZw ziUA?FE;+45KA*GdUZY)2{Y)*o^75e^5a0wO5uOXmeC2Ni2&%t9KB{X{Lj}j|0$xAg zbtQ;r-S-Mh4EE>0YUdkpjPZGqyf~Rl30<%Tmi}xXjWIzr8X!*>?Gfqv{E$=;gG`M@ zN?y`y|8|p9Z=qC(#9@g+ckRJJwtIB@4&<(OLp&P;(6F9uJq4L9;Jhp=;4BExAJ5HHi^k=jl?Zi?^7{0!@W!JO*JP-qU zLdLf?vRU_!SHGYiuC8XXWP7QNemxlQi~0!k4=nU~;YYtt7{&_B?pmS(g*I@2h}{5w zCwR%nE1~xfQ5`_7#@j!Pmn>JrbXQqRUlb zUX|Y$(tfQxe63I$8+-r%BkSD*V#@#j@hDo?*d(_aHe`j|OKJ|=l0}BK6{1<}s-#9j z_c_rjx6INKMT0h>%Sa`E^xrKpZ=!SaxzFGZ^V<;}NdPEe(*(8WA0iIOXFXA3 zYrE_&MA&JzGkHXOm#eqyEHXl=GD{LMnlEV1VnK`Kn5ctQK^dR4^TM4B%-pm7$zWjN z(ArI}r$DE>-Tm2?}}g>XO3^XTL6S8IGy;ENB{?~rrM-Aq zIKF5<6AY)1d-V`6FM>+kD?#v8*ILO*l#5hO{F1`y+VI?3AfP>UBD7+4MFeX1g4<%=QT>47ID=BbF*rQ+??WHz z_`B=FAv~t$4H>`SVoMNwicRN5*ini|HFs*ChN~Q-@Mq(56TQLa_S(oyj9}e>$HrLf zKfttSR)Xx>7`Yj*Q11U5wj9~S;F&qs^Bu~&l?H3N`o;qtT2slUD-$tp zR~kOaThok%PrVZf#fHkC74wHl(Y8!W*xP@^E2YT@udbjZuN!V`a81eQ?34KOv(hM} zzGjf4e2bVJrcXJMjMU)P-te-l4?G87sj`61YN`WT8qD=!SnQ+yjZcE0L7w1ZHE__q zAaCE!ozpFb>wnnVfm$S1O7nzpt_O2Jf^9X@e*eD>33EQ$TzLT+*%&U}8 zayU9$d{cEQOBbGFn9tzJao+=h4xvsF+)geBu+mT_g?K=+h1tOasR~-B7&o0S+>7n* zs-#aQ1(aOx7sWO&!?+a}Dr?!;=2gT@eb9PJYbZREt$&(7SakB#O>bCQA{IT2o=%=C zN8%tI-YKHQuCiDnZ^w@m0!K$gP8G&b=UuIR`rg9DXkG({>bwbZ{wMUrHV<_!z39peN-^alg%1UL;Tn(d%It`zSU62cT;3i8RbkRma z?14r4{q4h@>dL~mof4eLWc5g=T17&S!Q?{i`9QezJv2~)cn$DR+gVICwO3} z%l6FB7LD*pgugsEnmX1u4)$fY-LuuU4=g7gtT$ZMnY;|P-g&fi8M2=vok?j9?||56 zF7A?7)QEI-VC{VJTbe0)IPjyvGXA6!xSmAeh~^1!GXW$G=>h6-79PAA34at(PJA?2 zJ`~GAa*kI=79*B65A#hr1HaGMm*{J=T3xX3K)uIf?6>4=1+HoJG3ATw^bcL1Z|WDV5X86JI7k)yG_6^%EAJ;UoOkaYd&sY|%Ht z8^|^dd8p(UaJW%`G}S}?GZu?h%+$X;;V6?&`X z{IerEqmZTJkA;oEPobq5!Y9(9QrlK9ktR&x{tWIB;w8v#Dx$8{{iUpAEh7@OQ+>Iw0a5*v|6qQJa$NAtLdi@*6M3SUb2Xz`9Qhf&t)XSm zut1(x=gsch^&iEB{fXo8l}lQ<=J#fzTQPky)wVEkKFQMqj@Bd z>-uJYhG&3-qm~T__2ZaOW6lfU_}z;z!(6*;zumoe*1nL_{jYOn6Lxh(-&mFZCOmNH zhv+LyCDDl|Be|JKRb)p*Nb_@}5xi{cG-^tkwj#{BjDMyC!E(CRQj8;3$YakI`FHt4 zTyWQ?#IkN~L*AH)M~_-rUDD1ocARz5Lmj9GsI;_zI#S*;EH0Q;vT!dJyIudC@EyJq z`)wQhz;w@}b*s3O^50xBJzZsfZKmOSLo>ZV?$_Z`YSNtkL-Y&b$daicw9bkP-|kUv zt{bDE-lpx9q|Ln&rbc$gd)*pul$PG80R)H1Jykv2h+m8*B7X^+y~zpP#GR5ku4y4G z%?*Xu8Tq)<{MF~~FS&AvRc>j=txve}a>^Tr$^GnNkG5OujQ!C~^c)Y|l&4umToyfg z23k__TbuBM_x8+O%Rft>v`#@e#~BQI5oH`3dvuC{vB3!X!KzTFPE)>JOQeaGx0i4= z>)rpT@??S=r;;?@*7V_6q~B}k(D(5=>>-G?GrdekayhDuFvl;{T{p?&%6v63_qSaS zRc>M8krkFS8Co4i1PZu)Kt~?=wUc*)0X)A0U)tsVEewC{<#9VZOab~=bMW7`z+E2= zc7pe$DeEni0^zEI1WiUD{H?)!dPLJ1wk*3FWnm{oIz4Fetf;!t(H{l{3Ml!F{kQwf zjRiRHt;XBi4&-;WO3k530x(6dTRw($|Ft_jnUjPAc;4XMHlwri-@o?$_pf?Gadq^k zVz>H0HoWx{XW8B6<^+@fp`G!er}oZQwnF27y&F0GUusD+CRGh9q^|fdVOIur2+q@| z){cxX+rPk|=3id$XKJ;C^0~tP8U*3MyyD5ov;a%cqitMD>@F=NIx_&wJMILv!8@4UZL=H$+JQNE_`q%a(7YpH*x- zgw2kyI&o~Ht>e55sBHcj)G+x~pFr@86UF=G^(1DW_NeKezBq^8UJ=zuCJ&bkEwMR=!uo`Io&PbSD`3*9vX-(K=-! 
zJA9_0-WOZAHVI;s5<6_4BA&KrM$^lB56fwaaU&eqYg-{T?T6|o+xQHx zB)_J2?0c{nvzt?{eUF4rJ`*dK@&eZIYuI}tmUj#~y`qTwE@S6E8f-^d6!}&PX3dc< z_B+a)PgXO(t&doL!kd4AX(=wH&L2Q~=zJ|N3|~e34FrHD`ndk4f|zG|OJ+^Nr1onQ zyyX(nsy`P_`_*E0h+zGMse`w(#BtC03%5-O$-{!DW_4v1_&s@$)goagrTS0r>r~no zM&3#|;h(q9LfV)X5UWV}x@`qCZ}L82tTfXwq-&tK@`3=A_P24;brHVUF)4s8+A(bN|c`6pSd* zgU}}3GKBXc-+ickmcl*%?%q96Xus7SpW)}g5TtR;hIlXNn2ovgot!oRk;-|preyWA zdMVmt+BK6eUA$Bo&!=Sh4GQ&}zlkJJ9yblzp3|!`PEf=GN+&PN$?tY$Kc^5!-jLTE z?I7fw9pRg_MQ_4GPYcH|mq8S(cn^`RKZTr+vQXv4qc#n5AiU{kmcO($#vYLrO)=~MRT`~xEoc=g< z%3DGEgE1lPEItMR{g8B*XWggn`-By|raZ8kc_~a}bjgo+3Zh{23t=8S1Og~ya%pc+ zYnv}D9XS*8e0eQ$4f(Ei{i~PZ9mkBFd9&H_!aqhhSl-}kbxOJfBCyf*cnC6;+i{ID zN|i44&~Mp=(Zjn~v4UrOJbL9~KXR@!MY9CGDflPQ;!$lJZS5MHwYxX)D*bR7C*bDH z9&To4G0}QLO=h816#2>Efexw(n4XRPC5!t`O8rxqHC~2wMYPZ-^#Qja71{_Z{;fr@ zQH(_(ZrwU7k#tM_D_#eb+c{X`M*|l&=E&$=6?+yibRpc$G(71#g!3!|)g1p%C@8%ACVBb66r04!@ zA(aewN(Vue=r#m}x|top1Hi4;{Hd70JwuL-Sfrq{)LUY;+()^={?vwJ&Z{50-5L9| zl)m)l&Hc~c(QB*+Ao3aAhHq+`$j&00xKmZOXTnYj^V+eQ&;Kczt0-IGUrP7@0i^e{px3OrP3y9EE`Y{_W~|r2)&GJ=JhH5D5gB*d>|3QkH?6^dbr&VdiPhp z-sc{@&tXdLqTy};&W2rgpFRJ#Bl;_a9Vnwe!d?c!tEt29TF8@is%H_KW4Q&Q*I(!+_hTqu(Ckbf<`Dkj_Q+|!4zlnf1y4&d)dv{!r$>PP?0~41 z!dLMEWo|v6>PeT-y$5lXCZ#(2dTS}N)!PE)(>a?BY6czITF9GOP`ZL11P-(0@A>Vs zE;Do}3uyPrWihC19aYd72(q1*S`dyD1LKh?VZm;)Sr5+B-DxU`KolY zYI;)yyOZ4bx%H~Q!n=0;mo}^~QPFPoS03S=4c`Be3lBss+Tqq;ar|r3FT|4L%)W#; zJJyYw;uflBQktU1-21QBpr=ZrV<|rK7IQ}&{7nAY9s6^#B@BK67MSYgUDkG|@Awt; zBahgd1&fXgzYSqZh;h3BCOUfrc6V+iVcVzE@I{IENZ9h{z<3n1b7WFls$_Uz473rE zV{nG2RmMJi*=x!K$KVGRt(QHG$;rY=b~om~r`SwAuzqeLQT(4vJ6$~(awtF91jBe* z$;I$+V%c8o<`UEC&35wq*tQt;9aa)ZZ0w(5Gc&)Fz6xW zxlV98BIx$Jq?zxm+#WkdbQq@O>_72-ZOpW9M;I(Miuc<6nfnPrgbB6)!A*r7p9XC` zxzoX#wX3hd`Xy_ekBBs*S6tX)i#^D-*hxkucJS~E$?u%}@J7^!6U8<~f0tcIUWM5z z{K}MAFwfGcG43jnTK81H^U%))S%0?ga z*xDB9-V`n`NpEVu20&Zudq2w+pT83S-)@NXFQ+I&JN(Rwoa@TF(9QyM-&!(H{HaxA zgIP`JoefjAkJvDcI^0R6pKLXqY_D~s(vBRQD|+zKWEEx%_F*PGzQvd|IoA6=Ch0<6$|!9il#cwenNoKaZj>z%Q#R*HhQ2d%J^SC{=VO>!4~ zRoi7?Kik`4t>6PSl|%R1BIZ+}K@HV)3y{xGLvQ&_r{MDMlV6A!;U_$T4?EE-vtp~o zq30_O?kUB;Z*|Q`Bny}G&K$vQSHe|fMR&SNQ1DYkrYM7Q5P&)F(ov;#;dT>jAEuf6#OW(PhQgomnOJM-6w z0~v0i{~g^04u(T?vHz~SUxm}|PK)S^U}<$SVZ9J{NLE*m0`+{5uRzO!`)-0MH>^h2nH_;*)_yrZagq~Y}xN_gFik9g&7oscWzY?K3&s`V><;prgQQO5Jic% z{c~*)8H6#Zc?Sp30w9s``|5FpPWnHxPN#jt6l74nLL`JiI$KuPH&|x_8i3NfaM>#;a70A%4z)%mb9a>Kf(f6ITq69V!E^LA?m$}HhwOvo zD*Ci+B^C~`)nX(t`KInTK=Mc?q`97C;6*00gg|E-634wF$Zz`mVT;eD zT3Ro1+-7Lzb?mt*-2y;BwBw1_0&IfsZ6+Vh8&!Qr#5lDO$u62Fa{IlZC{TWV_3GTv zGCz`HbT#C9=(#x2G1+9)nVy$6K+~@wh3s)s+y&9hPYvYwKy*u=K-z^sC>{oL3udF` zN&v>!>X;|q#rO^u&1?;S(x@tyrF9EBj&-f9pBlfdCG}9Ufa|~5`$P$Ur~2U5qPjfk zOcw=Ux%=!q1w%4b=2&tg0vi5?A{YEy7ZDd<#*oHR zQ!~jbazPW?#Zo+i6ME@JhpiMMGOK%M+OXp)1@}x)a`7`z1tmrg3TTbT(Ny*X?cL79@9XbuF!L%ffS&*DBiue2PS6DKu63VJ2)H!CAbHg*%BFU>2VhEQ zyN5_hhKYn}Jv-J}j81yMA76rMajpXnJMHu!__Z~;v{qY9a0B8?IpbxK50*O?oBa1U zsgJKUIp51kKHJDIaoMcX5p$NwgT$TZI%LwSwv-kb^;a;QJPCv>RwnfiJdoTT$VeaL0T(Zwju{i5!BSwkf9 z;NC9=e|N`A+`l^;59a}0fsNGxCi?>cUd6|@HN0yvH;VcVwqkpe@FXM!*uxK1ycoxM z_m;hG%zO9ieyg}|W}mt-<~im2wMaInnL>W!tDix>(i9-)$yM7BUlueGNu%~EfqWvWP*TJkyb+j z67(yI%{7;f(=%4_&NfwxI%8kxx6dS}nxD_Ou`>Oob>OZ+=lvT;%pJLHosmJ6eh8|4 z6L2DqX*vP<1*`Aknny4dYt}fdE255(FG+Tp&WjxupI3|y?KK!!I@;F{hQZeNGK(eP!VXQvMmw{o*_Bo|nAs8RpPKya%rT6Sf9Y7YEX7_jY_H zWRx!jllDm_)2-t(dwV$0-2O$C>K>D z=!lwsm4zmib^l`w$5zk^wsX%UHbU--Op^Dn{b&$p${cD<>w}nhk9cGe@3q7p5o=MM z&W*M9^@Ueb)<*;PKxw@oa}vQOJzI%1Int@ph>Lf{P};PUQ0KQmEH^mKA3`J@l)inK z6eCNL|EL2MuEkG9`JE(MCi$^OYdr)vxdfSLLJuHSU_~BKobmf39B=gWm4YRKZ^Txl zRG;?a z?k>1nQ7djwsg!_fgKR#X4<@NB&C5aQl4Qc^M_=aPvUg}bPpeVOBrRvTLAGp?$!2$j 
zABaLKH&#y%=-IV@ZguDruUO7()Bcn-7D{>^BHs)yYoXu*cykle)`*o@>io1WZgX0> z?xBVEPc6Dj>2k@Ag(jMz)4zNu%lamv?7aJQrn{R*dmYnei}~9RXwVv~HF?jB4WlpK zJl@fvc)7mfWUY^=ylJPrY5@%f*f?ol_}$I8vM>O@*aP~eP+?k)$=+j zp6{UeWRfIBZLj3tp-;@-TwUcE?)RQCsNEyr>b_w%qmS)MQ3$#RlnnIF;TG$&1FKwT zVeC$NxfJ}9HLZH%`dLuC1Y|qm1zo<0(t405G$^qaqNPjFBj0~m`R?GB#%pC7rh- z92(=&bi2tKG^S*pfO1$;usml{iRFP`utsz1{q8oe59%|)Q%4ZdNat;)MIR;C_PI7Q zsVo1~k>kQ?*vtwqO0m6S^p)p+WqQ-Sqe7j}nOcwg&SvsU8_CkQoXkwKj6B-=$V_kl zCnbrk4AaR+<~=%AyDB_a)x3{8|4MjoV(nW(Rg{zb&xfuteBA~PHy=&wtPXvs)`B-N z$|QeMR9ola{p|cWt9fY<&~fzf8S9dT%dt#mvT^7*p(h{vkMBEBCr5c?LKYcGm^`aOugh>VTg|Ur;tJv~j03pm3`;ZI#L;OLUtTilNx$)# zh>h&lJxd)Ge;wl8GcJzcOxk~ZV7EyCwvNNdD?Q6TRekTp%AypqZAoTi(6$;rMV5?Lfv38lNL5q`yM2I=Egdp7on%)>!hze1i931&Ml<;S+5nN1w7Z zKVBQ!th27Ke_TR1QO=z!%Wo8-PTVZEP0M=^LZ*GU!K~@Guidzxc7fR73)2sffPdYH z!IvJN&tX&8r$4rj{^JuEvuDy!vPmUUPuYn}sMCu$N-|W^lBUc;#7`mqjtjaq>iQve zm5y|BTEXv2T=-eC+#LxggF!K1`Z~8yeD_~f^MjK8M6yb(-Prf%5`R) zF=MzWLaWF1P%-KekxX<6&SYQO+x=h)lZ!;sLdSCUTHU}GQe8DIx=!?*9HD}vy?o-0c9(y z*DBW;%kXrxjgq^MxogvV`Dae1?_iprJSslugMB}|(M>ePT%C%2+P&ku~OLgq&UvL<<~ z0BE6Terz;wq`3xt9;@LxRv~|e*lMEXw{Alo;;r#2%nZs<&zhxIbr4W9*`^7*b~VE{ zjQZ@3(DSetMUatO?B`Mlg}wyBW}i$g2EbB+%gNW=Tr`^7kwwquou^KWa4pG0a*d<# zl_k>^?{78H@OcHh4SMOOE4M~De?NDrQ|Zg~Mnn~ug|X#B*t=*rBLbIj-sY3i-vYxQhF||bav=-uekoxNw`puJs~)kz41h)7oY3W3ArTN6 zM20cV@H>zXHPA;puHk9|h&w9_EvT*KEt$y=4`;m#N?#n>vJs&JaAqN^LonipE{?WE2N$)&Rx4) ze1!g*@g(be(pbxbDJ$A%>eybBfia5Flvr+F>~qGHw|oubF{a%7xVTX4G-jGwbVv4Q zqzy4nZL7YAYIQ`=Yd|p}BwEfs)ACv_J6}Y}Disqcn-x@n_G!#N80Me=}UWKabQQ(7>)_Gn3h@ z!~a*k@hds|cCSs-XKU=%8N)2*LIx(1vb<8vJ7b5qkV#s#reVMxHrTa)SgDrF=?!aS zWoz4x6Q~K#DM%_bJ1`A>UqoT`2a?u(9~NBa8$BUI^vkl2+JM9#U_(r{*z#Q@I#GSF zI{F_lmF%q`Mfx=5HPQIegwhFiZm;_9M9Tc8G4Q!PqP}_#h9>74a3y!8Njqev{PzMi z^{VGuqC=TMWvPSwIHvzJJH7x4$j}LsB^2m z4!wNdQRZN)wm|oyblNsD&$D{)007_-ZI*05RiAp=L()~X2L1sQbm8UgaTGO#>ag#v zIDUJa6stl0zxI_ejs^^VPfZRqSwwxO1AHSPU>cK#$y>ZBB8V~qp%}L;cIIcj2>s0XT zJY6XHmL$>Kwmte|YcJgd4T>n_pModXK8yj4N&nY3m+BYuouiJr9QkeJ7&9GQX%vLf zrw@hjVVakst4pwz^>Y(BGUA74@t)WY(V-{thMO0XhfZbu@q25R$%O8|pSJ#a%+rzb z3RSNNHz~Uh+vW!Uuza6f=9w!LGB8G;bQt{Nzo&acZgRp!^jfBCM$j@ z$6Q2=7q{HLsl)emo*egdMB$jsADu&g9@Q2UJJ)Pk#Qv2_`=@dPqR_uR4FaT()>&Ei z?7S#qF-2weD6c*!ELW~x)J(XIJK6`h$wO)ZbuXYL_BS8#3mxqaYv|z zKcQcEWOkp9yMBmy$a$J;5Q3to+s&RTtY5%pqZ-PK zS<*G~A$RD?P}u!Zv;&O)8$V}3pn#mr59EpIeD8s9{gIGD&2J5>*v5WgZu#ZxQYK0L6WvbXF30SPeSaMzB0Ll zdy2T;OPwp1u^eqX(ODtxdi|CcFzc;7p$nbxPHnF9D!dE0fRC35f-N^6k2gT|4d&T4 zCzL;#fextU!Cll_4?p%#4<1d2ICBx-Ws=Dz&&=QIG6 zEi%UrGj$G}=lol!OZm(@@qk?73)kkW+ZX|vhey(o7esxP06A=+NYLtKEWm;}-5&N( zh5AZp8Zn%fK!xf(l##?`)2J($uY|d=FP;q`xmk`oi zw;>GI=v`z@jqJ9$yH3M;=#zi#2Bh5QK3vd=%W5CQSqictBj9+| zW~J-1H>Ry^#kXGZ4UhZ*!{l0a>CV66z~MhE8`+L8CsIo6A`os=>)sH0SK6gTGuo7o z;*K47xVGb78(MgZ?&wh}cwDi?FDRmb?O~B{7r(?QtE&<#>YCS!@AK_n8*@eU-gNyW zES)qpd0VML1=g^YE^tTmoutQ{>BFD#)~Gqk8?~rNg%)uC;c{izylvTm8N*_Nqdpq+ zn|nXpjlTE8bwhagpfusmrNy2T7<5;D4PZRQ^)JT`OP0ER)yw_DELxCt)MOz8gz3qsX(cLegl0H6nrHvy_cTBBPx| zDs^ORs==<9}QMQxx!ns8q3#Dz{3^6&)FAr`x;A9Z?|uXaN-zjywSxU|7UuNUzm-dvTS8Y zo89;B?S*Db>p-cy$%N&wxy^qJln#xeyEbSIfC$uEzjan*R-O2-VxyZ6Iuxgu2C*Nk8g$9c`)SCjumgFqhH8%$2s^W)v-%P# zY8MO2KvERjW~3nAgdF92lPjWMqouaEWcR zq^qjw@vZ@d^O|hxxF3Tv!+g;ZMktpe-QbH5GtM6f;W_W%Ja31(U{l`PEyQSY|HPdy(LTa z-V}L)@7}m^FK;t;+r0S8!9joQ)z*vaTEy-=CV z-FiZ*>&n=Ed@3kseQ|BWKzlD_O|E)j06R@g_wo-WW^Wa z1r()jE-e29sJMM$Z>+EJ@z~$FyXe->ZA_%|CIIz&&x<4nU6Kf@5GU-35eO^H*h26S z=PM6FVAdQmwz5-(Pf5{SuggFh4kz1E7fe?-l?1V#M{e{R41tYWfu^EAvgJYZS^Lw_ zQkPYWe}4JSztsy1>f9pX=g7G&bAD4!44-_p#=X^+jUR-gh-njRp0eFHNWfc@Lt2(dp2`Xs*p}es3NnTb-53A`bH}=P+ 
From: Paul Asquin
Date: Fri, 9 Feb 2024 14:07:09 +0100
Subject: [PATCH 09/23] add test model

---
 ml_mgie/tests/units/test_model.py | 55 +++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 ml_mgie/tests/units/test_model.py

diff --git a/ml_mgie/tests/units/test_model.py b/ml_mgie/tests/units/test_model.py
new file mode 100644
index 0000000..c6bb036
--- /dev/null
+++ b/ml_mgie/tests/units/test_model.py
@@ -0,0 +1,55 @@
+import torch
+from pathlib import Path
+from PIL import Image
+from ml_mgie.mgie import (
+    DEFAULT_IM_END_TOKEN,
+    DEFAULT_IM_START_TOKEN,
+    DEFAULT_IMAGE_PATCH_TOKEN,
+    MGIE,
+    MGIEParams,
+)
+from ml_mgie.utils import remove_alter
+
+
+TEST_IMAGE_PATH = Path(__file__).parents[1] / "data/0.jpg"
+TEST_INSTRUCTION = "make the frame red"
+assert TEST_IMAGE_PATH.exists()
+params = MGIEParams()
+mgie = MGIE(params=params)
+
+
+def test_prepare():
+    # Prepare prompt
+    instruction = TEST_INSTRUCTION
+    image_path = TEST_IMAGE_PATH
+    prompt_ids, prompt_mask = mgie.prepare_prompt_id_and_mask(instruction=instruction)
+    assert instruction in mgie.tokenizer.decode(prompt_ids)
+
+    # Prepare image
+    image = Image.open(image_path).convert("RGB")
+    img = mgie.prepare_img(image)
+    assert img.shape == (3, 224, 224)
+
+
+def test_generate():
+    # Prepare inputs
+    instruction = TEST_INSTRUCTION
+    image_path = TEST_IMAGE_PATH
+    image = Image.open(image_path).convert("RGB")
+    img = mgie.prepare_img(image)
+    prompt_tensor_ids, mask = mgie.prepare_prompt_id_and_mask(instruction)
+    with torch.inference_mode():
+        out = mgie.model.generate(
+            prompt_tensor_ids.unsqueeze(dim=0).to(params.device),
+            images=img.half().unsqueeze(dim=0).to(params.device),
+            attention_mask=mask.unsqueeze(dim=0).to(params.device),
+            do_sample=False,
+            max_new_tokens=96,
+            num_beams=1,
+            no_repeat_ngram_size=3,
+            return_dict_in_generate=True,
+            output_hidden_states=True,
+        )
+    out = out["sequences"][0].tolist()
+    out = remove_alter(mgie.tokenizer.decode(out))
+    assert not "Pres flash togful calledgot" in out, f"Nonesense: {out}"

From e08043375c777be8b1102b9e57b00de4b60e56a8 Mon Sep 17 00:00:00 2001
From: Paul Asquin
Date: Fri, 9 Feb 2024 18:30:22 +0100
Subject: [PATCH 10/23] rm old llava

---
ml_mgie/ml_mgie/llava.py | 605 --------------------------------------- 1 file changed, 605 deletions(-) delete mode 100644 ml_mgie/ml_mgie/llava.py diff --git a/ml_mgie/ml_mgie/llava.py b/ml_mgie/ml_mgie/llava.py deleted file mode 100644 index 19262e6..0000000 --- a/ml_mgie/ml_mgie/llava.py +++ /dev/null @@ -1,605 +0,0 @@ -# PURE LLAVA FROM ML_MGIE EXCEPTS REGISTER EXIST OK -# modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py - -import os -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - CLIPImageProcessor, - CLIPVisionModel, - LlamaConfig, - LlamaForCausalLM, - LlamaModel, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) - -DEFAULT_IMAGE_TOKEN = "" -DEFAULT_IMAGE_PATCH_TOKEN = "" -DEFAULT_IM_START_TOKEN = "" -DEFAULT_IM_END_TOKEN = "" - - -class LlavaConfig(LlamaConfig): - model_type = "llava" - - -class LlavaLlamaModel(LlamaModel): - config_class = LlavaConfig - - def __init__(self, config: LlamaConfig): - super(LlavaLlamaModel, self).__init__(config) - - if hasattr(config, "mm_vision_tower"): - # HACK: for FSDP - self.vision_tower = [ - CLIPVisionModel.from_pretrained(config.mm_vision_tower) - ] - # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) - - if hasattr(config, "use_mm_proj"): - self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) - - def get_vision_tower(self): - vision_tower = getattr(self, "vision_tower", None) - if type(vision_tower) is list: - vision_tower = vision_tower[0] - return vision_tower - - def initialize_vision_modules( - self, - vision_tower, - mm_vision_select_layer, - pretrain_mm_mlp_adapter=None, - fsdp=None, - ): - self.config.mm_vision_tower = vision_tower - - image_processor = CLIPImageProcessor.from_pretrained(vision_tower) - - if not hasattr(self, "vision_tower"): - vision_tower = CLIPVisionModel.from_pretrained(vision_tower) - else: - vision_tower = self.vision_tower[0] - vision_tower.requires_grad_(False) - - if fsdp is not None and len(fsdp) > 0: - self.vision_tower = [vision_tower] - else: - self.vision_tower = vision_tower - - vision_config = vision_tower.config - num_patches = (vision_config.image_size // vision_config.patch_size) ** 2 - - self.config.use_mm_proj = True - self.config.mm_hidden_size = vision_config.hidden_size - self.config.mm_vision_select_layer = mm_vision_select_layer - - if not hasattr(self, "mm_projector"): - self.mm_projector = nn.Linear( - vision_config.hidden_size, self.config.hidden_size - ) - - if pretrain_mm_mlp_adapter is not None: - mm_projector_weights = torch.load( - pretrain_mm_mlp_adapter, map_location="cpu" - ) - self.mm_projector.load_state_dict( - {k.split(".")[-1]: v for k, v in mm_projector_weights.items()} - ) - - return dict( - image_processor=image_processor, - image_token_len=num_patches, - vision_config=vision_config, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - 
# HACK: replace back original embeddings for LLaVA pretraining - orig_embeds_params = getattr(self, "orig_embeds_params", None) - # if orig_embeds_params is not None: - # orig_embeds_params = orig_embeds_params[0] - # with torch.no_grad(): - # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - vision_tower = self.get_vision_tower() - if ( - vision_tower is not None - and (input_ids.shape[1] != 1 or self.training) - and images is not None - ): - # TODO: this is a modified multimodal LLM -- Haotian Liu - with torch.no_grad(): - if type(images) is list: - # variable length images - image_features = [] - for image in images: - image_forward_out = vision_tower( - image.unsqueeze(0), output_hidden_states=True - ) - select_hidden_state_layer = getattr( - self.config, "mm_vision_select_layer", -1 - ) - select_hidden_state = image_forward_out.hidden_states[ - select_hidden_state_layer - ] - image_feature = select_hidden_state[:, 1:] - image_features.append(image_feature) - else: - image_forward_outs = vision_tower( - images.to(vision_tower.dtype), output_hidden_states=True - ) - select_hidden_state_layer = getattr( - self.config, "mm_vision_select_layer", -1 - ) - select_hidden_state = image_forward_outs.hidden_states[ - select_hidden_state_layer - ] - image_features = select_hidden_state[:, 1:].to(images.dtype) - if type(images) is list: - image_features = [ - self.mm_projector(image_feature)[0] - for image_feature in image_features - ] - else: - image_features = self.mm_projector(image_features) - dummy_image_features = torch.zeros( - 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) - dummy_image_features = self.mm_projector(dummy_image_features) - - new_input_embeds = [] - cur_image_idx = 0 - for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): - if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: - # multimodal LLM, but the current sample is not multimodal - cur_input_embeds = ( - cur_input_embeds + (0.0 * dummy_image_features).sum() - ) - new_input_embeds.append(cur_input_embeds) - cur_image_idx += 1 - continue - if vision_tower.config.use_im_start_end: - cur_image_features = image_features[cur_image_idx] - num_patches = cur_image_features.shape[0] - if (cur_input_ids == vision_tower.config.im_start_token).sum() != ( - cur_input_ids == vision_tower.config.im_end_token - ).sum(): - raise ValueError( - "The number of image start tokens and image end tokens should be the same." - ) - image_start_tokens = torch.where( - cur_input_ids == vision_tower.config.im_start_token - )[0] - for image_start_token_pos in image_start_tokens: - cur_image_features = image_features[cur_image_idx].to( - device=cur_input_embeds.device - ) - num_patches = cur_image_features.shape[0] - if ( - cur_input_ids[image_start_token_pos + num_patches + 1] - != vision_tower.config.im_end_token - ): - raise ValueError( - "The image end token should follow the image start token." 
- ) - if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat( - ( - cur_input_embeds[:image_start_token_pos].detach(), - cur_input_embeds[ - image_start_token_pos : image_start_token_pos - + 1 - ], - cur_image_features, - cur_input_embeds[ - image_start_token_pos - + num_patches - + 1 : image_start_token_pos - + num_patches - + 2 - ], - cur_input_embeds[ - image_start_token_pos + num_patches + 2 : - ].detach(), - ), - dim=0, - ) - else: - cur_new_input_embeds = torch.cat( - ( - cur_input_embeds[: image_start_token_pos + 1], - cur_image_features, - cur_input_embeds[ - image_start_token_pos + num_patches + 1 : - ], - ), - dim=0, - ) - cur_image_idx += 1 - new_input_embeds.append(cur_new_input_embeds) - else: - cur_image_features = image_features[cur_image_idx] - num_patches = cur_image_features.shape[0] - if ( - cur_input_ids == vision_tower.config.im_patch_token - ).sum() != num_patches: - raise ValueError( - "The number of image patch tokens should be the same as the number of image patches." - ) - masked_indices = torch.where( - cur_input_ids == vision_tower.config.im_patch_token - )[0] - mask_index_start = masked_indices[0] - if ( - masked_indices - != torch.arange( - mask_index_start, - mask_index_start + num_patches, - device=masked_indices.device, - dtype=masked_indices.dtype, - ) - ).any(): - raise ValueError( - "The image patch tokens should be consecutive." - ) - if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat( - ( - cur_input_embeds[:mask_index_start].detach(), - cur_image_features, - cur_input_embeds[ - mask_index_start + num_patches : - ].detach(), - ), - dim=0, - ) - else: - cur_new_input_embeds = torch.cat( - ( - cur_input_embeds[:mask_index_start], - cur_image_features, - cur_input_embeds[mask_index_start + num_patches :], - ), - dim=0, - ) - new_input_embeds.append(cur_new_input_embeds) - cur_image_idx += 1 - inputs_embeds = torch.stack(new_input_embeds, dim=0) - - return super(LlavaLlamaModel, self).forward( - input_ids=None, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - -class EditMapper(nn.Module): - def __init__(self): - super().__init__() - - self.llm2hid = nn.Linear(4096, 512) - self.query = nn.Parameter(torch.randn(1, 77, 512)) - self.mapper = nn.Transformer( - batch_first=True, - norm_first=True, - d_model=512, - nhead=4, - num_encoder_layers=4, - num_decoder_layers=4, - dim_feedforward=2048, - dropout=0.0, - ) - self.hid2feat = nn.Linear(512, 768) - - def forward(self, llm, emb): - hid = self.llm2hid(llm + emb) - hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1)) - feat = self.hid2feat(hid) - - return feat - - -class LlavaLlamaForCausalLM(LlamaForCausalLM): - config_class = LlavaConfig - - def __init__(self, config): - super(LlamaForCausalLM, self).__init__(config) - self.model = LlavaLlamaModel(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.edit_head = EditMapper() - - """self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'), - diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'), - diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')] - self.vae.requires_grad_(False) - self.unet.register_to_config(in_channels=8) - 
with torch.no_grad(): - conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) - conv.weight.zero_() - conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) - self.unet.conv_in = conv""" - - # Initialize weights and apply final processing - self.post_init() - - def get_model(self): - return self.model - - def get_vision_tower(self): - return self.get_model().get_vision_tower() - - def get_vision_tower(self): - model = self.get_model() - vision_tower = model.vision_tower - if type(vision_tower) is list: - vision_tower = vision_tower[0] - return vision_tower - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - images: Optional[torch.FloatTensor] = None, - return_dict: Optional[bool] = None, - p2p_inp=None, - p2p_ans=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - images=images, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model/pipeline parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if labels is not None: - llm = [] - for i in range(labels.shape[0]): - try: - p = labels[i].data.cpu().tolist().index(32003) - 1 - except: - p = len(labels[i]) - 9 - p = min(len(hidden_states[i]) - 9, p) - llm.append(hidden_states[i][p : p + 8].unsqueeze(0)) - llm = torch.cat(llm, dim=0) - hid_edit = self.edit_head( - llm, - self.model.embed_tokens.weight[-8:] - .unsqueeze(dim=0) - .repeat(labels.shape[0], 1, 1), - ) - - B, DROP = labels.shape[0], 0.05 - - hid_null = self.edit_head( - torch.zeros(B, 8, 4096, device=labels.device), - self.model.embed_tokens.weight[-8:] - .unsqueeze(dim=0) - .repeat(labels.shape[0], 1, 1), - ) - - with torch.no_grad(): - lat_ans, lat_inp = ( - self.vae.encode(p2p_ans).latent_dist.sample() - * self.vae.config.scaling_factor, - self.vae.encode(p2p_inp).latent_dist.mode(), - ) - lat_ans, lat_inp = [ - torch.from_numpy(lat_ans.data.cpu().float().numpy()).to( - lat_ans.device - ), - torch.from_numpy(lat_inp.data.cpu().float().numpy()).to( - lat_inp.device - ), - ] - - noise = torch.randn_like(lat_ans) - ts = 
torch.randint( - 0, self.scheduler.config.num_train_timesteps, (B,), device=noise.device - ).long() - lat_noise = self.scheduler.add_noise(lat_ans, noise, ts) - - prob = torch.rand(B, device=lat_ans.device) - mask = (prob < (DROP * 2)).reshape(B, 1, 1) - hid_edit = torch.where(mask, hid_null, hid_edit) - mask = ( - 1.0 - - ( - (prob >= DROP).to(lat_inp.dtype) - * (prob < (DROP * 3)).to(lat_inp.dtype) - ) - ).reshape(B, 1, 1, 1) - lat_inp *= mask - - out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample - - loss_ce, loss_edit = loss, nn.functional.mse_loss( - out, noise, reduction="mean" - ) - if int(os.environ["LOCAL_RANK"]) == 0: - print("loss_ce:", loss_ce, "/", "loss_edit:", loss_edit) - loss = loss_ce + loss_edit * 0.5 - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "images": kwargs.get("images", None), - } - ) - return model_inputs - - def initialize_vision_tokenizer( - self, - mm_use_im_start_end, - tokenizer, - device, - tune_mm_mlp_adapter=False, - pretrain_mm_mlp_adapter=None, - ): - vision_config = self.get_vision_tower().config - vision_config.use_im_start_end = mm_use_im_start_end - tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) - self.resize_token_embeddings(len(tokenizer)) - - if mm_use_im_start_end: - num_new_tokens = tokenizer.add_tokens( - [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True - ) - self.resize_token_embeddings(len(tokenizer)) - ( - vision_config.im_start_token, - vision_config.im_end_token, - ) = tokenizer.convert_tokens_to_ids( - [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] - ) - - if num_new_tokens > 0: - input_embeddings = self.get_input_embeddings().weight.data - output_embeddings = self.get_output_embeddings().weight.data - - input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True - ) - output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True - ) - - input_embeddings[-num_new_tokens:] = input_embeddings_avg - output_embeddings[-num_new_tokens:] = output_embeddings_avg - - if tune_mm_mlp_adapter: - self.get_model().orig_embeds_params = [ - self.get_input_embeddings().weight.data.clone().to(device=device) - ] - for p in self.get_input_embeddings().parameters(): - p.requires_grad = True - for p in self.get_output_embeddings().parameters(): - p.requires_grad = False - - if pretrain_mm_mlp_adapter: - mm_projector_weights = torch.load( - pretrain_mm_mlp_adapter, map_location="cpu" - ) - embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] - assert num_new_tokens == 2 - if input_embeddings.shape == embed_tokens_weight.shape: - input_embeddings[-num_new_tokens:] = embed_tokens_weight[ - 
-num_new_tokens: - ] - elif embed_tokens_weight.shape[0] == num_new_tokens: - input_embeddings[-num_new_tokens:] = embed_tokens_weight - else: - raise ValueError( - f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}." - ) - - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( - [DEFAULT_IMAGE_PATCH_TOKEN] - )[0] - - -# AutoConfig.register("llava", LlavaConfig) -# AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) -AutoConfig.register("llava", LlavaConfig, exist_ok=True) -AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM, exist_ok=True) From e6783e8d396933a417a4c0a4cfa133f632089d3e Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 22:12:41 +0100 Subject: [PATCH 11/23] more refacto, add cli --- .gitignore | 4 +- .gitmodules | 3 - LLaVA | 1 - Makefile | 2 +- ml_mgie/README.md | 25 +++- ml_mgie/demo/demo.py | 51 ++++++++ ml_mgie/demo/inference.py | 88 ------------- ml_mgie/ml_mgie/llava_conversation.py | 1 + ml_mgie/ml_mgie/main.py | 62 +++++++++ ml_mgie/ml_mgie/mgie.py | 175 +++++++++++++++++--------- ml_mgie/ml_mgie/mgie_llava.py | 2 - ml_mgie/ml_mgie/utils.py | 42 ++++--- ml_mgie/tests/units/test_model.py | 24 ++-- mypy.ini | 8 ++ pyproject.toml | 14 +-- 15 files changed, 304 insertions(+), 198 deletions(-) delete mode 160000 LLaVA create mode 100644 ml_mgie/demo/demo.py delete mode 100644 ml_mgie/demo/inference.py create mode 100644 ml_mgie/ml_mgie/main.py create mode 100644 mypy.ini diff --git a/.gitignore b/.gitignore index 58f9a56..07e2b10 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,6 @@ cython_debug/ venv* .DS_Store *.pt -*.tar.gz \ No newline at end of file +*.tar.gz +*.jpg +data/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index f1b7c9f..e69de29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "LLaVA"] - path = LLaVA - url = https://github.com/haotian-liu/LLaVA diff --git a/LLaVA b/LLaVA deleted file mode 160000 index 7ace501..0000000 --- a/LLaVA +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7ace501183c4bdec6052ec1a30039cdc3242a67c diff --git a/Makefile b/Makefile index 60f1d8e..736ef6a 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -PROJECT_NAME := ml-mgie +PROJECT_NAME := ml_mgie .PHONY: venv venv: diff --git a/ml_mgie/README.md b/ml_mgie/README.md index 1f3bdfc..3c0b28f 100644 --- a/ml_mgie/README.md +++ b/ml_mgie/README.md @@ -1,16 +1,33 @@ # ML-MGIE Packaging -**Work In Progress**: package ml-mgie, simplify dependencies, make compatible with MPS, CUDA and CPU - -Packaging contributors: -- Paul Asquin +**Work In Progress**: refacto, package, simplify dependencies, make compatible with MPS, CUDA and CPU +- by Paul Asquin ## Installation ```bash +make venv +source venv_*/bin/activate poetry install ``` +## Models download +Temporary from unofficial [huggingface.co/paulasquin/ml-mgie](https://huggingface.co/paulasquin/ml-mgie) for simplification conveniance. 
+```bash +git lfs install +git clone https://huggingface.co/paulasquin/ml-mgie ./data +``` + ## Demo ```bash poetry run python ml_mgie/demo/inference.py +``` + +## Typing check and tests +```bash +poetry run make tests +``` + +## Usage +```bash +poetry run python -m ml_mgie.main --input_path _input/0.jpg --instruction "make the frame red" --output_path red_glasses.jpg ``` \ No newline at end of file diff --git a/ml_mgie/demo/demo.py b/ml_mgie/demo/demo.py new file mode 100644 index 0000000..63ca6c8 --- /dev/null +++ b/ml_mgie/demo/demo.py @@ -0,0 +1,51 @@ +import os +import shutil +from pathlib import Path + +from ml_mgie.mgie import MGIE, MGIEParams +from PIL import Image +from tqdm import tqdm + +SEED = 13331 +CFG_TXT = 7.5 +CFG_IMG = 1.5 +params = MGIEParams() +mgie = MGIE(params=params) +input_path = Path("_input") +output_path = Path("_output") +os.makedirs(output_path, exist_ok=True) + +ins = [ + "make the frame red", + "turn the day into night", + "give him a beard", + "make cottage a mansion", + "remove yellow object from dogs paws", + "change the hair from red to blue", + "remove the text", + "increase the image contrast", + "remove the people in the background", + "please make this photo professional looking", + "darken the image, sharpen it", + "photoshop the girl out", + "make more brightness", + "take away the brown filter form the image", + "add more contrast to simulate more light", + "dark on rgb", + "make the face happy", + "change view as ocean", + "replace basketball with soccer ball", + "let the floor be made of wood", +] +for i in tqdm(range(len(ins))): + image_input_path = input_path / f"{i}.jpg" + image = Image.open(image_input_path).convert("RGB") + instruction = ins[i] + + result_image, inner_thought = mgie.edit( + image=image, + instruction=instruction, + ) + print(f"Inner thought: {inner_thought}") + result_image.save(output_path / f"{i}-out.jpg") + shutil.copy(image_input_path, output_path / f"{i}-in.jpg") diff --git a/ml_mgie/demo/inference.py b/ml_mgie/demo/inference.py deleted file mode 100644 index 5232eeb..0000000 --- a/ml_mgie/demo/inference.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import shutil -from pathlib import Path - -import torch -from ml_mgie.mgie import MGIE, MGIEParams -from ml_mgie.utils import remove_alter -from PIL import Image -from tqdm import tqdm - -SEED = 13331 -CFG_TXT = 7.5 -CFG_IMG = 1.5 -params = MGIEParams() -mgie = MGIE(params=params) -input_path = Path("_input") -output_path = Path("_output") -os.makedirs(output_path, exist_ok=True) - -ins = [ - "make the frame red", - "turn the day into night", - "give him a beard", - "make cottage a mansion", - "remove yellow object from dogs paws", - "change the hair from red to blue", - "remove the text", - "increase the image contrast", - "remove the people in the background", - "please make this photo professional looking", - "darken the image, sharpen it", - "photoshop the girl out", - "make more brightness", - "take away the brown filter form the image", - "add more contrast to simulate more light", - "dark on rgb", - "make the face happy", - "change view as ocean", - "replace basketball with soccer ball", - "let the floor be made of wood", -] -for i in tqdm(range(len(ins))): - image_input_path = input_path / f"{i}.jpg" - image = Image.open(image_input_path).convert("RGB") - instruction = ins[i] - - # Prepare inputs - img = mgie.prepare_img(image) - prompt_tensor_ids, mask = mgie.prepare_prompt_id_and_mask(instruction) - with torch.inference_mode(): - out = mgie.model.generate( - 
prompt_tensor_ids.unsqueeze(dim=0).to(params.device), - images=img.half().unsqueeze(dim=0).to(params.device), - attention_mask=mask.unsqueeze(dim=0).to(params.device), - do_sample=False, - max_new_tokens=96, - num_beams=1, - no_repeat_ngram_size=3, - return_dict_in_generate=True, - output_hidden_states=True, - ) - import pdb - - pdb.set_trace() - # Here out is nonesense: "Pres flash togful calledgot At commitilli split sent" - out, hid = ( - out["sequences"][0].tolist(), - torch.cat([x[-1] for x in out["hidden_states"]], dim=1)[0], - ) - - p = min(out.index(32003) - 1 if 32003 in out else len(hid) - 9, len(hid) - 9) - hid = hid[p : p + 8] - - out = remove_alter(mgie.tokenizer.decode(out)) - emb = mgie.model.edit_head(hid.unsqueeze(dim=0), mgie.emb) - res: Image.Image = mgie.pipe( - image=Image.open(image_input_path).convert("RGB"), - prompt_embeds=emb, - negative_prompt_embeds=mgie.null, - generator=torch.Generator(device=params.device).manual_seed(SEED), - guidance_scale=CFG_TXT, - image_guidance_scale=CFG_IMG, - ).images[0] - # Save results before/after - print(f"Instruction: {instruction}") - print(f"Output: {out}") - shutil.copy(image_input_path, output_path / f"{i}-in.jpg") - res.save(output_path / f"{i}-out.jpg") diff --git a/ml_mgie/ml_mgie/llava_conversation.py b/ml_mgie/ml_mgie/llava_conversation.py index c7d3216..0052999 100644 --- a/ml_mgie/ml_mgie/llava_conversation.py +++ b/ml_mgie/ml_mgie/llava_conversation.py @@ -1,3 +1,4 @@ +# from https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py import dataclasses from enum import Enum, auto from typing import List diff --git a/ml_mgie/ml_mgie/main.py b/ml_mgie/ml_mgie/main.py new file mode 100644 index 0000000..7e4988d --- /dev/null +++ b/ml_mgie/ml_mgie/main.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path +from typing import Any + +from ml_mgie.mgie import MGIE, MGIEParams +from PIL import Image +from simple_parsing import ArgumentParser + + +@dataclass +class Command: + def run(self): + raise NotImplementedError() + + +@dataclass +class MGIECommand: + """""" + + input_path: Path + instruction: str + output_path: Path + params: MGIEParams = MGIEParams() + + @cached_property + def mgie(self) -> MGIE: + return MGIE(params=self.params) + + def run(self): + print( + f"Running MGIE command with instruction: {self.instruction} onto image: {self.input_path}" + ) + image = Image.open(self.input_path).convert("RGB") + result_image, inner_thought = self.mgie.edit( + image=image, + instruction=self.instruction, + ) + print(f"Inner thought: {inner_thought}") + result_image.save(self.output_path) + print(f"Saved result image to: {self.output_path}") + + +@dataclass +class Program: + command: MGIECommand + + def run(self) -> Any: + return self.command.run() + + +if __name__ == "__main__": + parser = ArgumentParser( + prog="ml-mgie.main", + description=""" + ML-MGIE Rag: Guiding Instruction-based Image Editing via Multimodal Large Language Models. 
+ """, + ) + parser.add_arguments(Program, dest="program") + args = parser.parse_args() + program: Program = args.program + program.run() diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index c664256..7eb9752 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -5,35 +5,45 @@ import diffusers import torch import transformers +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix import ( + StableDiffusionInstructPix2PixPipeline, +) from PIL import Image from .base import DEFAULT_DEVICE - -# from .mgie_llava import LlavaLlamaForCausalLM -from .llava import LlavaLlamaForCausalLM from .llava_conversation import conv_templates -from .utils import crop_resize - -# from llava.conversation import conv_templates -# from llava.model import * - +from .mgie_llava import LlavaLlamaForCausalLM +from .utils import crop_resize, remove_alter DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" -PATH_LLAVA = Path("./_ckpt/LLaVA-7B-v1") -PATH_MLLM = Path("./_ckpt/mgie_7b/mllm.pt") -PATH_UNET = Path("./_ckpt/mgie_7b/unet.pt") - -assert PATH_LLAVA.exists() -assert PATH_MLLM.exists() -assert PATH_UNET.exists() @dataclass class MGIEParams: device: torch.device = DEFAULT_DEVICE + dtype: torch.dtype = torch.float16 + models_path: Path = Path("./data") + seed: int = 13331 + cfg_txt: float = 7.5 + cfg_img: float = 1.5 + + @property + def mllm_path(self) -> Path: + assert (path := self.models_path / "mgie_7b/mllm.pt").exists() + return path + + @property + def unet_path(self) -> Path: + assert (path := self.models_path / "mgie_7b/unet.pt").exists() + return path + + @property + def llava_path(self) -> Path: + assert (path := self.models_path / "LLaVA-7B-v1").exists() + return path class MGIE: @@ -44,26 +54,25 @@ def __init__(self, params: MGIEParams = MGIEParams()) -> None: self.image_processor: transformers.CLIPImageProcessor = None self.image_token_len: int = None self.emb: torch.Tensor = None - self.pipe: diffusers.StableDiffusionInstructPix2PixPipeline = None + self.pipe: StableDiffusionInstructPix2PixPipeline = None self._set_model() def _set_model(self): - tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA) + transformers.logging.set_verbosity_error() + # Prepare llava model = LlavaLlamaForCausalLM.from_pretrained( - PATH_LLAVA.absolute(), - # low_cpu_mem_usage=True, - torch_dtype=torch.float16, + self.params.llava_path, + torch_dtype=self.params.dtype, use_cache=True, ).to(self.params.device) - model.model = model.model.to(self.params.device) - model.model.vision_tower[0] = model.get_vision_tower().to(self.params.device) - model.lm_head = model.lm_head.to(self.params.device) - model.edit_head = model.edit_head.to(self.params.device) + # Prepare CLIP image_processor = transformers.CLIPImageProcessor.from_pretrained( - model.config.mm_vision_tower, torch_dtype=torch.float16 + model.config.mm_vision_tower, torch_dtype=self.params.dtype ) + # Prepare tokenizer + tokenizer = transformers.AutoTokenizer.from_pretrained(self.params.llava_path) tokenizer.padding_side = "left" tokenizer.add_tokens( [ @@ -79,11 +88,18 @@ def _set_model(self): special_tokens=True, ) model.resize_token_embeddings(len(tokenizer)) - # ckpt = torch.load(PATH_MLLM, map_location=self.params.device) # TO DEVICE? 
- ckpt = torch.load(PATH_MLLM, map_location="cpu") - # incompatible_keys = model.load_state_dict(ckpt, strict=False, assign=True) + ckpt = torch.load(self.params.mllm_path, map_location="cpu") incompatible_keys = model.load_state_dict(ckpt, strict=False) + transformers.logging.set_verbosity_warning() + # Patch model + model.model.vision_tower[0] = model.get_vision_tower().to( + self.params.device, dtype=self.params.dtype + ) + model.lm_head = model.lm_head.to(self.params.device, dtype=self.params.dtype) + model.edit_head = model.edit_head.to( + self.params.device, dtype=self.params.dtype + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: @@ -91,16 +107,16 @@ def _set_model(self): [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True ) + # Patch CLIP vision_tower = model.get_vision_tower() - vision_tower: transformers.CLIPVisionModel = ( - transformers.CLIPVisionModel.from_pretrained( - vision_tower.config._name_or_path, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - ).to(self.params.device) - ) + vision_tower = transformers.CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=self.params.dtype, + low_cpu_mem_usage=True, + ).to(self.params.device) model.model.vision_tower[0] = vision_tower - vision_config: transformers.PretrainedConfig = vision_tower.config + # Patch CLIP config + vision_config = vision_tower.config vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( [DEFAULT_IMAGE_PATCH_TOKEN] )[0] @@ -114,29 +130,27 @@ def _set_model(self): ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 - # model = model.to(self.params.device) - _ = model.eval() - emb = ckpt["emb"].to(self.params.device) + # Prepare placeholders + emb = ckpt["emb"].to(self.params.device, dtype=self.params.dtype) with torch.inference_mode(): null = model.edit_head( - torch.zeros(1, 8, 4096).half().to(self.params.device), + torch.zeros(1, 8, 4096).to(self.params.device, dtype=self.params.dtype), emb, ) + # Prepare diffusion pipeline pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained( - "timbrooks/instruct-pix2pix", - torch_dtype=torch.float16, # , safety_checker=None - ).to(self.params.device) + "timbrooks/instruct-pix2pix", safety_checker=None + ) pipe.set_progress_bar_config(disable=True) - """pipe.unet.load_state_dict( - torch.load(PATH_UNET.absolute(), map_location=self.params.device), - assign=True, - strict=True, - ) # TO DEVICE?""" pipe.unet.load_state_dict( - torch.load(PATH_UNET.absolute(), map_location="cpu"), - ) # TO DEVICE? 
+ torch.load(self.params.unet_path.absolute(), map_location="cpu"), + strict=True, + ) + pipe = pipe.to(device=self.params.device, dtype=self.params.dtype) + # Set attributes + model.eval() self.model = model self.tokenizer = tokenizer self.image_processor = image_processor @@ -148,10 +162,10 @@ def _set_model(self): def prepare_img(self, image: Image.Image) -> torch.Tensor: """image: PIL.Image.Image, Pillow RGB image""" img = crop_resize(image) - img = self.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][ - 0 - ] - return img + img_tensor: torch.Tensor = self.image_processor.preprocess( + img, return_tensors="pt" + )["pixel_values"][0] + return img_tensor.to(device=self.params.device, dtype=self.params.dtype) def prepare_prompt_id_and_mask( self, instruction: str @@ -169,6 +183,55 @@ def prepare_prompt_id_and_mask( conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() prompt_tokenized = self.tokenizer(prompt) - prompt_tensor_ids = torch.as_tensor(prompt_tokenized["input_ids"]) - mask = torch.as_tensor(prompt_tokenized["attention_mask"]) + prompt_tensor_ids = torch.as_tensor( + prompt_tokenized["input_ids"], + device=self.params.device, + ) + mask = torch.as_tensor( + prompt_tokenized["attention_mask"], + device=self.params.device, + ) return prompt_tensor_ids, mask + + def edit(self, image: Image.Image, instruction: str) -> Tuple[Image.Image, str]: + """ + image: PIL.Image.Image, Pillow RGB image + instruction: str, edition to perform on image + """ + # Prepare inputs + img = self.prepare_img(image) + prompt_tensor_ids, mask = self.prepare_prompt_id_and_mask(instruction) + with torch.inference_mode(): + out = self.model.generate( + prompt_tensor_ids.unsqueeze(dim=0), + images=img.unsqueeze(dim=0), + attention_mask=mask.unsqueeze(dim=0), + do_sample=False, + max_new_tokens=96, + num_beams=1, + no_repeat_ngram_size=3, + return_dict_in_generate=True, + output_hidden_states=True, + ) + + out, hid = ( + out["sequences"][0].tolist(), + torch.cat([x[-1] for x in out["hidden_states"]], dim=1)[0], + ) + p = out.index(32003) - 1 if 32003 in out else len(hid) - 9 + p = min(p, len(hid) - 9) + hid = hid[p : p + 8] + + inner_thoughts = remove_alter(self.tokenizer.decode(out)) + emb = self.model.edit_head(hid.unsqueeze(dim=0), self.emb) + res: Image.Image = self.pipe( + image=image, + prompt_embeds=emb, + negative_prompt_embeds=self.null, + generator=torch.Generator(device=self.params.device).manual_seed( + self.params.seed + ), + guidance_scale=self.params.cfg_txt, + image_guidance_scale=self.params.cfg_img, + ).images[0] + return res, inner_thoughts diff --git a/ml_mgie/ml_mgie/mgie_llava.py b/ml_mgie/ml_mgie/mgie_llava.py index aee22fe..23cd47e 100644 --- a/ml_mgie/ml_mgie/mgie_llava.py +++ b/ml_mgie/ml_mgie/mgie_llava.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from ml_mgie.base import DEFAULT_DEVICE from torch.nn import CrossEntropyLoss from transformers import ( AutoConfig, @@ -41,7 +40,6 @@ class LlavaLlamaModel(LlamaModel): def __init__(self, config: LlamaConfig): super(LlavaLlamaModel, self).__init__(config) - self.to_device = DEFAULT_DEVICE if hasattr(config, "mm_vision_tower"): # HACK: for FSDP diff --git a/ml_mgie/ml_mgie/utils.py b/ml_mgie/ml_mgie/utils.py index 12eda69..13f1ea1 100644 --- a/ml_mgie/ml_mgie/utils.py +++ b/ml_mgie/ml_mgie/utils.py @@ -1,26 +1,28 @@ -# TODO: add typing -def crop_resize(f, sz=512): - w, h = f.size +from PIL import Image + + +def crop_resize(image: Image.Image, size: int = 512) -> Image.Image: + w, h = image.size if 
w > h: p = (w - h) // 2 - f = f.crop([p, 0, p + h, h]) + image = image.crop([p, 0, p + h, h]) elif h > w: p = (h - w) // 2 - f = f.crop([0, p, w, p + w]) - f = f.resize([sz, sz]) - return f + image = image.crop([0, p, w, p + w]) + image = image.resize([size, size]) + return image -def remove_alter(s): # hack expressive instruction - if "ASSISTANT:" in s: - s = s[s.index("ASSISTANT:") + 10 :].strip() - if "" in s: - s = s[: s.index("")].strip() - if "alternative" in s.lower(): - s = s[: s.lower().index("alternative")] - if "[IMG0]" in s: - s = s[: s.index("[IMG0]")] - s = ".".join([s.strip() for s in s.split(".")[:2]]) - if s[-1] != ".": - s += "." - return s.strip() +def remove_alter(prompt: str) -> str: # hack expressive instruction + if "ASSISTANT:" in prompt: + prompt = prompt[prompt.index("ASSISTANT:") + 10 :].strip() + if "" in prompt: + prompt = prompt[: prompt.index("")].strip() + if "alternative" in prompt.lower(): + prompt = prompt[: prompt.lower().index("alternative")] + if "[IMG0]" in prompt: + prompt = prompt[: prompt.index("[IMG0]")] + prompt = ".".join([s.strip() for s in prompt.split(".")[:2]]) + if prompt[-1] != ".": + prompt += "." + return prompt.strip() diff --git a/ml_mgie/tests/units/test_model.py b/ml_mgie/tests/units/test_model.py index c6bb036..a4d53cd 100644 --- a/ml_mgie/tests/units/test_model.py +++ b/ml_mgie/tests/units/test_model.py @@ -1,19 +1,14 @@ -import torch from pathlib import Path -from PIL import Image -from ml_mgie.mgie import ( - DEFAULT_IM_END_TOKEN, - DEFAULT_IM_START_TOKEN, - DEFAULT_IMAGE_PATCH_TOKEN, - MGIE, - MGIEParams, -) -from ml_mgie.utils import remove_alter +import torch +from ml_mgie.mgie import MGIE, MGIEParams +from ml_mgie.utils import remove_alter +from PIL import Image TEST_IMAGE_PATH = Path(__file__).parents[1] / "data/0.jpg" TEST_INSTRUCTION = "make the frame red" assert TEST_IMAGE_PATH.exists() +# params = MGIEParams(device=torch.device("cpu")) params = MGIEParams() mgie = MGIE(params=params) @@ -40,9 +35,9 @@ def test_generate(): prompt_tensor_ids, mask = mgie.prepare_prompt_id_and_mask(instruction) with torch.inference_mode(): out = mgie.model.generate( - prompt_tensor_ids.unsqueeze(dim=0).to(params.device), - images=img.half().unsqueeze(dim=0).to(params.device), - attention_mask=mask.unsqueeze(dim=0).to(params.device), + prompt_tensor_ids.unsqueeze(dim=0), + images=img.unsqueeze(dim=0), + attention_mask=mask.unsqueeze(dim=0), do_sample=False, max_new_tokens=96, num_beams=1, @@ -52,4 +47,5 @@ def test_generate(): ) out = out["sequences"][0].tolist() out = remove_alter(mgie.tokenizer.decode(out)) - assert not "Pres flash togful calledgot" in out, f"Nonesense: {out}" + # Ensuring no dtype-introduced nonesense MPS/torch 2.2.0: "Pres flash togful calledgot At commitilli split sent" + assert "The frame would be red" in out diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..d93616b --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +no_strict_optional = True +ignore_missing_imports = True +exclude = (?x)( + llava_conversation\.py + | mgie_llava\.py + ) +namespace_packages = False diff --git a/pyproject.toml b/pyproject.toml index 89b7231..426a1c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,6 @@ [tool.poetry] name = "ml-mgie" -packages = [ - { include = "ml_mgie", from = "./ml_mgie" }, - # { include = "llava", from = "./LLaVa" }, -] +packages = [{ include = "ml_mgie", from = "./ml_mgie" }] version = "v0.0.0" description = "" authors = [ @@ -13,7 +10,7 @@ authors = [ "William Yang Wang", "Yinfei 
Yang", "Zhe Gan", - "Paul Asquin", # only as package contributor + "Paul Asquin", # as porting and package contributor ] [build-system] @@ -24,16 +21,17 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.dependencies] python = ">=3.10,<3.11" -torch = "^2.2.0" +torch = "2.1.0" tqdm = "^4.66.1" transformers = "^4.37.2" diffusers = "^0.26.2" sentencepiece = "^0.1.99" protobuf = "^4.25.2" accelerate = "^0.26.1" -clip = {git = "https://github.com/openai/CLIP.git"} +clip = { git = "https://github.com/openai/CLIP.git" } evaluate = "^0.4.1" tokenizers = "^0.15.1" -torchvision = "^0.17.0" +torchvision = "^0.16.0" deepspeed = "^0.13.1" ninja = "^1.11.1.1" +simple-parsing = "^0.1.5" From a6a0e569f52772d83317b57241163bec9b19760a Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 22:36:21 +0100 Subject: [PATCH 12/23] add max size --- ml_mgie/README.md | 2 +- ml_mgie/ml_mgie/mgie.py | 2 ++ pyproject.toml | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ml_mgie/README.md b/ml_mgie/README.md index 3c0b28f..bf965e9 100644 --- a/ml_mgie/README.md +++ b/ml_mgie/README.md @@ -29,5 +29,5 @@ poetry run make tests ## Usage ```bash -poetry run python -m ml_mgie.main --input_path _input/0.jpg --instruction "make the frame red" --output_path red_glasses.jpg +poetry run python -m ml_mgie.main --input_path _input/0.jpg --instruction "make the frame red" --output_path red_glasses.jpg --max_size 512 ``` \ No newline at end of file diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index 7eb9752..d9d62bf 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -29,6 +29,7 @@ class MGIEParams: seed: int = 13331 cfg_txt: float = 7.5 cfg_img: float = 1.5 + max_size: int = 512 @property def mllm_path(self) -> Path: @@ -199,6 +200,7 @@ def edit(self, image: Image.Image, instruction: str) -> Tuple[Image.Image, str]: instruction: str, edition to perform on image """ # Prepare inputs + image.thumbnail((self.params.max_size, self.params.max_size)) img = self.prepare_img(image) prompt_tensor_ids, mask = self.prepare_prompt_id_and_mask(instruction) with torch.inference_mode(): diff --git a/pyproject.toml b/pyproject.toml index 426a1c4..6157481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.dependencies] python = ">=3.10,<3.11" -torch = "2.1.0" +torch = "2.2.0" tqdm = "^4.66.1" transformers = "^4.37.2" diffusers = "^0.26.2" @@ -31,7 +31,7 @@ accelerate = "^0.26.1" clip = { git = "https://github.com/openai/CLIP.git" } evaluate = "^0.4.1" tokenizers = "^0.15.1" -torchvision = "^0.16.0" +torchvision = "^0.17.0" deepspeed = "^0.13.1" ninja = "^1.11.1.1" simple-parsing = "^0.1.5" From 710f92c2daf08f53ae1b9edf7d493e589a326391 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 22:36:54 +0100 Subject: [PATCH 13/23] allow none maxsize --- ml_mgie/ml_mgie/mgie.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index d9d62bf..e709ebd 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -200,7 +200,8 @@ def edit(self, image: Image.Image, instruction: str) -> Tuple[Image.Image, str]: instruction: str, edition to perform on image """ # Prepare inputs - image.thumbnail((self.params.max_size, self.params.max_size)) + if self.params.max_size: + image.thumbnail((self.params.max_size, self.params.max_size)) img = self.prepare_img(image) prompt_tensor_ids, mask = self.prepare_prompt_id_and_mask(instruction) with 
torch.inference_mode(): From c9100d9ef50265b8eadccadb4577bb7c334c1239 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Fri, 9 Feb 2024 23:04:14 +0100 Subject: [PATCH 14/23] default to float32 --- ml_mgie/ml_mgie/mgie.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index e709ebd..ee8caaf 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -24,7 +24,7 @@ @dataclass class MGIEParams: device: torch.device = DEFAULT_DEVICE - dtype: torch.dtype = torch.float16 + dtype: torch.dtype = torch.float32 models_path: Path = Path("./data") seed: int = 13331 cfg_txt: float = 7.5 From e5b4ea535d91050201c303a97e4ede682c03eead Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 01:14:27 +0100 Subject: [PATCH 15/23] add app gradio --- ml_mgie/README.md | 6 ++- ml_mgie/ml_mgie/app.py | 89 +++++++++++++++++++++++++++++++++++++++++ ml_mgie/ml_mgie/main.py | 6 +++ ml_mgie/ml_mgie/mgie.py | 12 ++++-- pyproject.toml | 1 + 5 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 ml_mgie/ml_mgie/app.py diff --git a/ml_mgie/README.md b/ml_mgie/README.md index bf965e9..448ac15 100644 --- a/ml_mgie/README.md +++ b/ml_mgie/README.md @@ -11,7 +11,7 @@ poetry install ``` ## Models download -Temporary from unofficial [huggingface.co/paulasquin/ml-mgie](https://huggingface.co/paulasquin/ml-mgie) for simplification conveniance. +Temporary from unofficial [huggingface.co/paulasquin/ml-mgie](https://huggingface.co/paulasquin/ml-mgie) for simplification convenience. ```bash git lfs install git clone https://huggingface.co/paulasquin/ml-mgie ./data @@ -30,4 +30,8 @@ poetry run make tests ## Usage ```bash poetry run python -m ml_mgie.main --input_path _input/0.jpg --instruction "make the frame red" --output_path red_glasses.jpg --max_size 512 +``` + +```bash +poetry run python -m ml_mgie.app ``` \ No newline at end of file diff --git a/ml_mgie/ml_mgie/app.py b/ml_mgie/ml_mgie/app.py new file mode 100644 index 0000000..ad555fc --- /dev/null +++ b/ml_mgie/ml_mgie/app.py @@ -0,0 +1,89 @@ +import os +from datetime import datetime +from pathlib import Path + +import gradio as gr +from ml_mgie.mgie import MGIE, MGIEParams +from PIL import Image + +DEBUG_PATH = Path("debug") +os.makedirs(DEBUG_PATH, exist_ok=True) +mgie = MGIE(params=MGIEParams()) + + +def go_mgie( + image: Image.Image, + instruction: str, + seed: int, + cfg_txt: float, + cfg_img: float, + max_size: int, + telemetry: bool, +): + params = MGIEParams(seed=seed, cfg_txt=cfg_txt, cfg_img=cfg_img, max_size=max_size) + mgie.params = params + name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + + if telemetry: + image.save(DEBUG_PATH / f"{name}-in.jpg") + + image.thumbnail((params.max_size, params.max_size)) + if telemetry: + image.save(DEBUG_PATH / f"{name}-thumb.jpg") + + result_image, inner_thoughts = mgie.edit(image=image, instruction=instruction) + if telemetry: + result_image.save(DEBUG_PATH / f"{name}-zout.jpg") + return result_image, inner_thoughts + + +with gr.Blocks() as app: + gr.Markdown( + "# Guiding Instruction-based Image Editing via Multimodal Large Language Models" + ) + with gr.Row(): + input_image, result_image = [ + gr.Image( + label="Input Image", + interactive=True, + height="500px", + type="pil", + image_mode="RGB", + ), + gr.Image( + label="Goal Image", type="pil", interactive=False, image_mode="RGB" + ), + ] + with gr.Row(): + instruction, inner_thoughts = [ + gr.Textbox(label="Instruction", interactive=True), + 
gr.Textbox(label="Expressive Instruction", interactive=False), + ] + with gr.Row(): + telemetry, seed, cfg_txt, cfg_img, max_size = [ + gr.Checkbox(label="Telemetry", value=True, interactive=True), + gr.Number(value=42, label="Seed", interactive=True, precision=0), + gr.Number(value=7.5, label="Text CFG", interactive=True), + gr.Number(value=1.5, label="Image CFG", interactive=True), + gr.Number( + minimum=1, + maximum=1024, + value=512, + precision=0, + label="Maximum Size", + interactive=True, + ), + ] + with gr.Row(): + btn_sub = gr.Button("Submit") + + btn_sub.click( + fn=go_mgie, + inputs=[input_image, instruction, seed, cfg_txt, cfg_img, max_size, telemetry], + outputs=[result_image, inner_thoughts], + concurrency_limit=1, + ) + + +app.queue() +app.launch(server_port=7122) diff --git a/ml_mgie/ml_mgie/main.py b/ml_mgie/ml_mgie/main.py index 7e4988d..1b77a63 100644 --- a/ml_mgie/ml_mgie/main.py +++ b/ml_mgie/ml_mgie/main.py @@ -21,6 +21,7 @@ class MGIECommand: input_path: Path instruction: str output_path: Path + save_thumbnail: bool = False params: MGIEParams = MGIEParams() @cached_property @@ -32,6 +33,11 @@ def run(self): f"Running MGIE command with instruction: {self.instruction} onto image: {self.input_path}" ) image = Image.open(self.input_path).convert("RGB") + if self.params.max_size: + image.thumbnail((self.params.max_size, self.params.max_size)) + if self.save_thumbnail: + image.save(self.output_path.with_suffix(".thumb.jpg")) + result_image, inner_thought = self.mgie.edit( image=image, instruction=self.instruction, diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index ee8caaf..c8503e2 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -24,13 +24,19 @@ @dataclass class MGIEParams: device: torch.device = DEFAULT_DEVICE - dtype: torch.dtype = torch.float32 + half: bool = False models_path: Path = Path("./data") seed: int = 13331 cfg_txt: float = 7.5 cfg_img: float = 1.5 max_size: int = 512 + @property + def dtype(self) -> torch.dtype: + if self.half: + return torch.float16 + return torch.float32 + @property def mllm_path(self) -> Path: assert (path := self.models_path / "mgie_7b/mllm.pt").exists() @@ -227,7 +233,7 @@ def edit(self, image: Image.Image, instruction: str) -> Tuple[Image.Image, str]: inner_thoughts = remove_alter(self.tokenizer.decode(out)) emb = self.model.edit_head(hid.unsqueeze(dim=0), self.emb) - res: Image.Image = self.pipe( + result_image: Image.Image = self.pipe( image=image, prompt_embeds=emb, negative_prompt_embeds=self.null, @@ -237,4 +243,4 @@ def edit(self, image: Image.Image, instruction: str) -> Tuple[Image.Image, str]: guidance_scale=self.params.cfg_txt, image_guidance_scale=self.params.cfg_img, ).images[0] - return res, inner_thoughts + return result_image, inner_thoughts diff --git a/pyproject.toml b/pyproject.toml index 6157481..b71418a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,3 +35,4 @@ torchvision = "^0.17.0" deepspeed = "^0.13.1" ninja = "^1.11.1.1" simple-parsing = "^0.1.5" +gradio = "^4.17.0" From 2ed91bdc0dfe07fc0532b520c5eb90247c47b968 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 01:26:24 +0100 Subject: [PATCH 16/23] default size --- ml_mgie/ml_mgie/app.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ml_mgie/ml_mgie/app.py b/ml_mgie/ml_mgie/app.py index ad555fc..1419bc4 100644 --- a/ml_mgie/ml_mgie/app.py +++ b/ml_mgie/ml_mgie/app.py @@ -46,7 +46,6 @@ def go_mgie( gr.Image( label="Input Image", interactive=True, - height="500px", type="pil", image_mode="RGB", ), From 
9ddfbfff7dea3373d752cff66583e6ba02d912bd Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 11:02:22 +0100 Subject: [PATCH 17/23] add back submodule llava to avoid perturbing legacy --- .gitmodules | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitmodules b/.gitmodules index e69de29..0d06bb4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "LLaVA"] + path = LLaVA + url = https://github.com/haotian-liu/LLaVA \ No newline at end of file From 371830bfbfde52456c69b84d3c0d78547ca77063 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 11:03:22 +0100 Subject: [PATCH 18/23] add newline --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 0d06bb4..f1b7c9f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "LLaVA"] path = LLaVA - url = https://github.com/haotian-liu/LLaVA \ No newline at end of file + url = https://github.com/haotian-liu/LLaVA From db3346380ce67dba8902ec34a0854cf16f74cd96 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 11:14:19 +0100 Subject: [PATCH 19/23] add back submodule llava for legacy --- .gitmodules | 2 +- LLaVA | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 160000 LLaVA diff --git a/.gitmodules b/.gitmodules index f1b7c9f..d7d0e09 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "LLaVA"] path = LLaVA - url = https://github.com/haotian-liu/LLaVA + url = https://github.com/haotian-liu/LLaVA/tree/7ace501183c4bdec6052ec1a30039cdc3242a67c diff --git a/LLaVA b/LLaVA new file mode 160000 index 0000000..7ace501 --- /dev/null +++ b/LLaVA @@ -0,0 +1 @@ +Subproject commit 7ace501183c4bdec6052ec1a30039cdc3242a67c From 73c5c25e28e180c0effa630b3e39d4bfbd37ecd1 Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 12:08:50 +0100 Subject: [PATCH 20/23] add inference as decorator, half as default --- ml_mgie/demo/demo.py | 5 +-- ml_mgie/ml_mgie/app.py | 24 +++-------- ml_mgie/ml_mgie/main.py | 2 +- ml_mgie/ml_mgie/mgie.py | 94 +++++++++++++++++++++-------------------- 4 files changed, 56 insertions(+), 69 deletions(-) diff --git a/ml_mgie/demo/demo.py b/ml_mgie/demo/demo.py index 63ca6c8..c468a2f 100644 --- a/ml_mgie/demo/demo.py +++ b/ml_mgie/demo/demo.py @@ -6,10 +6,7 @@ from PIL import Image from tqdm import tqdm -SEED = 13331 -CFG_TXT = 7.5 -CFG_IMG = 1.5 -params = MGIEParams() +params = MGIEParams(half=True, seed=13331, cfg_txt=7.5, cfg_img=1.5, max_size=512) mgie = MGIE(params=params) input_path = Path("_input") output_path = Path("_output") diff --git a/ml_mgie/ml_mgie/app.py b/ml_mgie/ml_mgie/app.py index 1419bc4..66d5811 100644 --- a/ml_mgie/ml_mgie/app.py +++ b/ml_mgie/ml_mgie/app.py @@ -1,14 +1,10 @@ -import os from datetime import datetime -from pathlib import Path import gradio as gr from ml_mgie.mgie import MGIE, MGIEParams from PIL import Image -DEBUG_PATH = Path("debug") -os.makedirs(DEBUG_PATH, exist_ok=True) -mgie = MGIE(params=MGIEParams()) +mgie = MGIE(params=MGIEParams(half=True)) def go_mgie( @@ -18,22 +14,15 @@ def go_mgie( cfg_txt: float, cfg_img: float, max_size: int, - telemetry: bool, ): + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + print(f"{timestamp} processing image with instruction: {instruction}") + params = MGIEParams(seed=seed, cfg_txt=cfg_txt, cfg_img=cfg_img, max_size=max_size) mgie.params = params - name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - - if telemetry: - image.save(DEBUG_PATH / f"{name}-in.jpg") 
image.thumbnail((params.max_size, params.max_size)) - if telemetry: - image.save(DEBUG_PATH / f"{name}-thumb.jpg") - result_image, inner_thoughts = mgie.edit(image=image, instruction=instruction) - if telemetry: - result_image.save(DEBUG_PATH / f"{name}-zout.jpg") return result_image, inner_thoughts @@ -59,8 +48,7 @@ def go_mgie( gr.Textbox(label="Expressive Instruction", interactive=False), ] with gr.Row(): - telemetry, seed, cfg_txt, cfg_img, max_size = [ - gr.Checkbox(label="Telemetry", value=True, interactive=True), + seed, cfg_txt, cfg_img, max_size = [ gr.Number(value=42, label="Seed", interactive=True, precision=0), gr.Number(value=7.5, label="Text CFG", interactive=True), gr.Number(value=1.5, label="Image CFG", interactive=True), @@ -78,7 +66,7 @@ def go_mgie( btn_sub.click( fn=go_mgie, - inputs=[input_image, instruction, seed, cfg_txt, cfg_img, max_size, telemetry], + inputs=[input_image, instruction, seed, cfg_txt, cfg_img, max_size], outputs=[result_image, inner_thoughts], concurrency_limit=1, ) diff --git a/ml_mgie/ml_mgie/main.py b/ml_mgie/ml_mgie/main.py index 1b77a63..6ac0135 100644 --- a/ml_mgie/ml_mgie/main.py +++ b/ml_mgie/ml_mgie/main.py @@ -59,7 +59,7 @@ def run(self) -> Any: parser = ArgumentParser( prog="ml-mgie.main", description=""" - ML-MGIE Rag: Guiding Instruction-based Image Editing via Multimodal Large Language Models. + ML-MGIE: Guiding Instruction-based Image Editing via Multimodal Large Language Models. """, ) parser.add_arguments(Program, dest="program") diff --git a/ml_mgie/ml_mgie/mgie.py b/ml_mgie/ml_mgie/mgie.py index c8503e2..7fd85ad 100644 --- a/ml_mgie/ml_mgie/mgie.py +++ b/ml_mgie/ml_mgie/mgie.py @@ -24,8 +24,8 @@ @dataclass class MGIEParams: device: torch.device = DEFAULT_DEVICE - half: bool = False - models_path: Path = Path("./data") + half: bool = True # Weights are half precision by default + models_path: Path = Path("./data") # Path to dir contaning mgie_7b and LLaVA-7B-v1 seed: int = 13331 cfg_txt: float = 7.5 cfg_img: float = 1.5 @@ -67,6 +67,7 @@ def __init__(self, params: MGIEParams = MGIEParams()) -> None: def _set_model(self): transformers.logging.set_verbosity_error() # Prepare llava + tokenizer = transformers.AutoTokenizer.from_pretrained(self.params.llava_path) model = LlavaLlamaForCausalLM.from_pretrained( self.params.llava_path, torch_dtype=self.params.dtype, @@ -79,7 +80,6 @@ def _set_model(self): ) # Prepare tokenizer - tokenizer = transformers.AutoTokenizer.from_pretrained(self.params.llava_path) tokenizer.padding_side = "left" tokenizer.add_tokens( [ @@ -97,6 +97,12 @@ def _set_model(self): model.resize_token_embeddings(len(tokenizer)) ckpt = torch.load(self.params.mllm_path, map_location="cpu") incompatible_keys = model.load_state_dict(ckpt, strict=False) + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) transformers.logging.set_verbosity_warning() # Patch model @@ -107,21 +113,15 @@ def _set_model(self): model.edit_head = model.edit_head.to( self.params.device, dtype=self.params.dtype ) - mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) - tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) - if mm_use_im_start_end: - tokenizer.add_tokens( - [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True - ) # Patch CLIP vision_tower = model.get_vision_tower() 
vision_tower = transformers.CLIPVisionModel.from_pretrained( vision_tower.config._name_or_path, torch_dtype=self.params.dtype, - low_cpu_mem_usage=True, ).to(self.params.device) model.model.vision_tower[0] = vision_tower + # Patch CLIP config vision_config = vision_tower.config vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( @@ -136,6 +136,7 @@ def _set_model(self): [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 + _ = model.eval() # Prepare placeholders emb = ckpt["emb"].to(self.params.device, dtype=self.params.dtype) @@ -147,17 +148,18 @@ def _set_model(self): # Prepare diffusion pipeline pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained( - "timbrooks/instruct-pix2pix", safety_checker=None + "timbrooks/instruct-pix2pix", + torch_dtype=self.params.dtype, + safety_checker=None, ) pipe.set_progress_bar_config(disable=True) pipe.unet.load_state_dict( - torch.load(self.params.unet_path.absolute(), map_location="cpu"), + torch.load(self.params.unet_path, map_location="cpu"), strict=True, ) - pipe = pipe.to(device=self.params.device, dtype=self.params.dtype) + pipe = pipe.to(device=self.params.device) # Set attributes - model.eval() self.model = model self.tokenizer = tokenizer self.image_processor = image_processor @@ -200,6 +202,7 @@ def prepare_prompt_id_and_mask( ) return prompt_tensor_ids, mask + @torch.inference_mode() def edit(self, image: Image.Image, instruction: str) -> Tuple[Image.Image, str]: """ image: PIL.Image.Image, Pillow RGB image @@ -210,37 +213,36 @@ def edit(self, image: Image.Image, instruction: str) -> Tuple[Image.Image, str]: image.thumbnail((self.params.max_size, self.params.max_size)) img = self.prepare_img(image) prompt_tensor_ids, mask = self.prepare_prompt_id_and_mask(instruction) - with torch.inference_mode(): - out = self.model.generate( - prompt_tensor_ids.unsqueeze(dim=0), - images=img.unsqueeze(dim=0), - attention_mask=mask.unsqueeze(dim=0), - do_sample=False, - max_new_tokens=96, - num_beams=1, - no_repeat_ngram_size=3, - return_dict_in_generate=True, - output_hidden_states=True, - ) + out = self.model.generate( + prompt_tensor_ids.unsqueeze(dim=0), + images=img.unsqueeze(dim=0), + attention_mask=mask.unsqueeze(dim=0), + do_sample=False, + max_new_tokens=96, + num_beams=1, + no_repeat_ngram_size=3, + return_dict_in_generate=True, + output_hidden_states=True, + ) - out, hid = ( - out["sequences"][0].tolist(), - torch.cat([x[-1] for x in out["hidden_states"]], dim=1)[0], - ) - p = out.index(32003) - 1 if 32003 in out else len(hid) - 9 - p = min(p, len(hid) - 9) - hid = hid[p : p + 8] - - inner_thoughts = remove_alter(self.tokenizer.decode(out)) - emb = self.model.edit_head(hid.unsqueeze(dim=0), self.emb) - result_image: Image.Image = self.pipe( - image=image, - prompt_embeds=emb, - negative_prompt_embeds=self.null, - generator=torch.Generator(device=self.params.device).manual_seed( - self.params.seed - ), - guidance_scale=self.params.cfg_txt, - image_guidance_scale=self.params.cfg_img, - ).images[0] + out, hid = ( + out["sequences"][0].tolist(), + torch.cat([x[-1] for x in out["hidden_states"]], dim=1)[0], + ) + p = out.index(32003) - 1 if 32003 in out else len(hid) - 9 + p = min(p, len(hid) - 9) + hid = hid[p : p + 8] + + inner_thoughts = remove_alter(self.tokenizer.decode(out)) + embedding = self.model.edit_head(hid.unsqueeze(dim=0), self.emb) + result_image: Image.Image = self.pipe( + image=image, + prompt_embeds=embedding, + 
negative_prompt_embeds=self.null, + generator=torch.Generator(device=self.params.device).manual_seed( + self.params.seed + ), + guidance_scale=self.params.cfg_txt, + image_guidance_scale=self.params.cfg_img, + ).images[0] return result_image, inner_thoughts From 84556f935b24349e83dc9b8b1fb2a8b0d411297a Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 12:13:17 +0100 Subject: [PATCH 21/23] rename fields in app --- ml_mgie/ml_mgie/app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml_mgie/ml_mgie/app.py b/ml_mgie/ml_mgie/app.py index 66d5811..8c7ad05 100644 --- a/ml_mgie/ml_mgie/app.py +++ b/ml_mgie/ml_mgie/app.py @@ -39,13 +39,13 @@ def go_mgie( image_mode="RGB", ), gr.Image( - label="Goal Image", type="pil", interactive=False, image_mode="RGB" + label="Generated Image", type="pil", interactive=False, image_mode="RGB" ), ] with gr.Row(): instruction, inner_thoughts = [ gr.Textbox(label="Instruction", interactive=True), - gr.Textbox(label="Expressive Instruction", interactive=False), + gr.Textbox(label="Inner thoughts", interactive=False), ] with gr.Row(): seed, cfg_txt, cfg_img, max_size = [ @@ -57,7 +57,7 @@ def go_mgie( maximum=1024, value=512, precision=0, - label="Maximum Size", + label="Maximum image size", interactive=True, ), ] From 6af3819c82f3cf88f28df370d9d16942774674bd Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 12:25:40 +0100 Subject: [PATCH 22/23] update package readme --- ml_mgie/README.md | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/ml_mgie/README.md b/ml_mgie/README.md index 448ac15..c031d20 100644 --- a/ml_mgie/README.md +++ b/ml_mgie/README.md @@ -1,7 +1,10 @@ # ML-MGIE Packaging -**Work In Progress**: refacto, package, simplify dependencies, make compatible with MPS, CUDA and CPU -- by Paul Asquin +Inference-oriented module +Added refacto, packaging, dependencies simplification, add compatibility with apple silicon + +Package contributors: +- [Paul Asquin](https://github.com/paulasquin) ## Installation ```bash @@ -19,7 +22,7 @@ git clone https://huggingface.co/paulasquin/ml-mgie ./data ## Demo ```bash -poetry run python ml_mgie/demo/inference.py +poetry run python ml_mgie/demo/demo.py ``` ## Typing check and tests @@ -31,7 +34,14 @@ poetry run make tests ```bash poetry run python -m ml_mgie.main --input_path _input/0.jpg --instruction "make the frame red" --output_path red_glasses.jpg --max_size 512 ``` +Get more information on CLI using +```bash +poetry run python -m ml_mgie.main --help +``` +## Gradio App ```bash poetry run python -m ml_mgie.app -``` \ No newline at end of file +``` + + From 4cdba2f3413bd76a61f90dd6898e364d4191c84e Mon Sep 17 00:00:00 2001 From: Paul Asquin Date: Sat, 10 Feb 2024 12:42:27 +0100 Subject: [PATCH 23/23] update tests and check image result with ssim --- ml_mgie/tests/units/test_model.py | 33 +++++++++++-------------------- pyproject.toml | 1 + 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/ml_mgie/tests/units/test_model.py b/ml_mgie/tests/units/test_model.py index a4d53cd..cd68009 100644 --- a/ml_mgie/tests/units/test_model.py +++ b/ml_mgie/tests/units/test_model.py @@ -1,14 +1,12 @@ from pathlib import Path -import torch from ml_mgie.mgie import MGIE, MGIEParams -from ml_mgie.utils import remove_alter from PIL import Image +from SSIM_PIL import compare_ssim TEST_IMAGE_PATH = Path(__file__).parents[1] / "data/0.jpg" TEST_INSTRUCTION = "make the frame red" -assert TEST_IMAGE_PATH.exists() -# params = 
MGIEParams(device=torch.device("cpu")) + params = MGIEParams() mgie = MGIE(params=params) @@ -17,6 +15,7 @@ def test_prepare(): # Prepare prompt instruction = TEST_INSTRUCTION image_path = TEST_IMAGE_PATH + assert image_path.exists() prompt_ids, prompt_mask = mgie.prepare_prompt_id_and_mask(instruction=instruction) assert instruction in mgie.tokenizer.decode(prompt_ids) @@ -31,21 +30,11 @@ def test_generate(): instruction = TEST_INSTRUCTION image_path = TEST_IMAGE_PATH image = Image.open(image_path).convert("RGB") - img = mgie.prepare_img(image) - prompt_tensor_ids, mask = mgie.prepare_prompt_id_and_mask(instruction) - with torch.inference_mode(): - out = mgie.model.generate( - prompt_tensor_ids.unsqueeze(dim=0), - images=img.unsqueeze(dim=0), - attention_mask=mask.unsqueeze(dim=0), - do_sample=False, - max_new_tokens=96, - num_beams=1, - no_repeat_ngram_size=3, - return_dict_in_generate=True, - output_hidden_states=True, - ) - out = out["sequences"][0].tolist() - out = remove_alter(mgie.tokenizer.decode(out)) - # Ensuring no dtype-introduced nonesense MPS/torch 2.2.0: "Pres flash togful calledgot At commitilli split sent" - assert "The frame would be red" in out + result_image, inner_thoughts = mgie.edit(image, instruction) + + # Check inner thoughts contains basic information + words = ["frame", "red", "glasses"] + assert [word in inner_thoughts for word in words].count(True) == len(words) + + # Check result image is not pure hallucination, close to original + assert compare_ssim(image, result_image, GPU=False) >= 0.8 diff --git a/pyproject.toml b/pyproject.toml index b71418a..5502bcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,3 +36,4 @@ deepspeed = "^0.13.1" ninja = "^1.11.1.1" simple-parsing = "^0.1.5" gradio = "^4.17.0" +ssim-pil = "^1.0.14"