Application Example

This article describes how Matx can be used in multimodal task scenarios, taking image/text multimodality as an example and focusing on the implementation of the preprocessing logic; the model itself is outside the scope of this article. Source code

1. Text Modality Preprocessing

The text-modality preprocessing below mainly implements the logic of a BertTokenizer.

1.1 Helper Functions

We first define some BertTokenizer-style text cleaning and pre-transformation logic (OPs); a short usage sketch follows the list.

  • Text cleaning

import matx

class TextCleaner:
    """TextCleaner impl by matx."""

    def __init__(self) -> None:
        self.white_regex: matx.Regex = matx.Regex(r"[ \t\n\r\p{Zs}]")
        self.control_regex: matx.Regex = matx.Regex(
            r"[\u0000\ufffd\p{Cc}\p{Cf}\p{Mn}]")

        self.space: bytes = " ".encode()
        self.empty: bytes = "".encode()

    def __call__(self, text: bytes) -> bytes:
        t = self.white_regex.replace(text, self.space)
        return self.control_regex.replace(t, self.empty)
  • Case normalization

class CaseNormalizer:
    def __init__(self, do_lowercase: bool = False, unicode_norm: str = '') -> None:
        self.do_lowercase: bool = do_lowercase

    def __call__(self, text: bytes) -> bytes:
        if self.do_lowercase:
            return text.lower()
        else:
            return text
  • Punctuation padding

class PunctuationPadding:
    """Pad a space around the punctuation."""

    def __init__(self):
        self.regex_pattern: matx.Regex = matx.Regex(
            r"([\u0021-\u002f]|[\u003a-\u0040]|[\u005b-\u0060]|[\u007b-\u007e]|\p{P})")
        self.replace_pattern: bytes = r" ${1} ".encode()

    def __call__(self, text: bytes) -> bytes:
        return self.regex_pattern.replace(text, self.replace_pattern)
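
These three OPs can be tried on their own before wiring them into the tokenizer. Below is a minimal sketch (not part of the original pipeline) that chains them eagerly in plain Python on a made-up byte string; they can also be compiled with matx.script, and the exact spacing of the output depends on the regex replacements above.

# A quick, illustrative check of the helper OPs on a made-up sentence.
cleaner = TextCleaner()
normalizer = CaseNormalizer(do_lowercase=True)
punc_padding = PunctuationPadding()

raw = "Hello, Matx!\tIt strips\u00a0odd whitespace.".encode()
cleaned = punc_padding(normalizer(cleaner(raw)))
print(cleaned)          # roughly b"hello ,  matx !  it strips odd whitespace . "
print(cleaned.split())  # the term list that WordPieceTokenizer will receive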

1.2 Matx-Based BertTokenizer

The following implements the BertTokenizer logic on top of Matx, using the text cleaning and case normalization utilities defined above.

from typing import List, Dict, Any

import matx
from matx.text import WordPieceTokenizer

class MatxBertTokenizer:
    def __init__(self,
                 vocab_path: str,
                 lower_case: bool = False,
                 max_tokens_per_input: int = 256,
                 unk_token: str = '[UNK]'
                 ) -> None:
        """
        matx style BertTokenzier。
        vocab_path: vocabulary path for tokenizer
        lower_case: convert to lowercase or not
        max_tokens_per_input: token length limit
        unk_token: the symbol for unknown tokens
        """
        self.cleaner: TextCleaner = TextCleaner()
        self.normalizer: CaseNormalizer = CaseNormalizer(True)
        self.punc_padding: PunctuationPadding = PunctuationPadding()
        self.max_tokens_per_input: int = max_tokens_per_input
        self.world_piece: Any = WordPieceTokenizer(vocab_path=vocab_path,
                                                   unk_token=unk_token,
                                                   max_bytes_per_token=max_tokens_per_input)
        self.cls_id: int = self.world_piece.tokenize(['[CLS]'])[0]
        self.sep_id: int = self.world_piece.tokenize(['[SEP]'])[0]
        self.pad_id: int = self.world_piece.tokenize(['[PAD]'])[0]


    def __call__(self, texts: List[bytes]) -> Dict[str, matx.NDArray]:
        batch_input_ids: List = []
        batch_input_mask: List = []
        batch_segment_ids: List = []
        for text in texts:
            text = self.cleaner(text)
            text = self.normalizer(text)
            text = self.punc_padding(text)
            terms: List = text.split()
            tokens: List[int] = self.world_piece.tokenize(terms)
            # start to create BERT-style input: reserve two slots for [CLS] and [SEP]
            len_trunc: int = self.max_tokens_per_input - 2
            input_ids: List = [self.cls_id] + tokens[:len_trunc] + [self.sep_id]
            input_mask: List = [1] * len(input_ids) + [0] * (self.max_tokens_per_input - len(input_ids))
            input_ids = input_ids + [self.pad_id] * (self.max_tokens_per_input - len(input_ids))
            segment_ids = [0] * self.max_tokens_per_input
            batch_input_ids.append(input_ids)
            batch_input_mask.append(input_mask)
            batch_segment_ids.append(segment_ids)
        res: Dict = {}
        res["input_ids"] = matx.NDArray(batch_input_ids, [], "int64")
        res["input_mask"] = matx.NDArray(batch_input_mask, [], "int64")
        res["segment_ids"] = matx.NDArray(batch_segment_ids, [], "int64")
        return res
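
The tokenizer can also be compiled and called on its own, in the same way the vision OPs are scripted below. A minimal sketch, assuming a BERT-style vocabulary file named vocab.txt (a placeholder; substitute your own path):

# Hypothetical standalone usage; "vocab.txt" is a placeholder vocabulary file.
tokenizer = matx.script(MatxBertTokenizer)(vocab_path="vocab.txt",
                                           lower_case=True,
                                           max_tokens_per_input=32)
batch = tokenizer([b"this is a demo", b"Matx makes preprocessing fast."])
print(batch["input_ids"].torch().shape)   # expected: torch.Size([2, 32])
print(batch["input_mask"].torch().shape)  # expected: torch.Size([2, 32])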

2. Image Modality Preprocessing

The following implements the image-modality preprocessing logic, using ResNet preprocessing as an example. The vision OPs involved are mainly Decode, RandomResizedCrop, CenterCrop, RandomHorizontalFlip, and Normalize.

from typing import List, Dict, Any
import matx
from matx.vision.tv_transforms import Decode, RandomHorizontalFlip, \
    RandomResizedCrop, CenterCrop, Normalize, Stack, Transpose, Compose

class MatxImagenetVisionProcessor:
    def __init__(self, device_id: int = -1, is_train: bool = True) -> None:
        self.is_train: bool = is_train
        vision_ops: List = []
        if is_train:  # image transform for training
            vision_ops = [
                matx.script(Decode)(to_rgb=True),
                matx.script(RandomResizedCrop)(size=[224, 224],
                                               scale=(0.08, 1.0),
                                               ratio=(0.75, 1.33)),
                matx.script(RandomHorizontalFlip)(),
                matx.script(Normalize)(mean=[123.675, 116.28, 103.53],
                                       std=[58.395, 57.12, 57.375]),
                matx.script(Stack)(),
                matx.script(Transpose)()
            ]
        else:  # image transform for evaluation
            vision_ops = [
                matx.script(Decode)(to_rgb=True),
                matx.script(CenterCrop)(size=[224, 224]),
                matx.script(Normalize)(mean=[123.675, 116.28, 103.53],
                                       std=[58.395, 57.12, 57.375]),
                matx.script(Stack)(),
                matx.script(Transpose)()
            ]
        self.vision_op: Any = matx.script(Compose)(device_id, vision_ops)

    def __call__(self, images: List[bytes]) -> matx.NDArray:
        return self.vision_op(images)
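
The vision processor can likewise be exercised in isolation. A minimal sketch, assuming a local demo.jpeg (the same placeholder image used in the DataLoader example below) and CPU execution via device_id=-1:

# Hypothetical standalone usage; "demo.jpeg" is a placeholder image file.
with open("demo.jpeg", "rb") as f:
    img_bytes = f.read()

vision_processor = matx.script(MatxImagenetVisionProcessor)(device_id=-1, is_train=False)
images = vision_processor([img_bytes, img_bytes])  # a batch of two identical images
print(images.torch().shape)  # e.g. torch.Size([2, 3, 224, 224]) after Stack/Transpose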

3. Combining the Image/Text Modalities

The text and image preprocessing can now be combined into the overall multimodal preprocessing logic:

@matx.script
class MultiModalPipeline:
    def __init__(self,
                 vocab_path: str,
                 lower_case: bool = False,
                 max_tokens_per_input: int = 256,
                 unk_token: str = '[UNK]',
                 vision_device_id: int = -1,
                 is_train: bool = True):
        self.text_processor: Any = MatxBertTokenizer(
            vocab_path, lower_case, max_tokens_per_input, unk_token
        )
        self.vision_processor: Any = MatxImagenetVisionProcessor(
            vision_device_id, is_train
        )

    # The input is a batch of data.
    # Assume each item looks like {"text": b"some text", "image": b"some image"}.
    # The output is collated; organize the result in any format you want.
    # The code below outputs the processed data like
    # {"images": batched_images, "input_ids": batched_input_ids, "input_mask": batched_input_mask, ...}
    def __call__(self, data: List[Dict[str, Any]]) -> Dict[str, matx.NDArray]:
        texts: List[bytes] = [item["text"] for item in data]
        images: List[bytes] = [item["image"] for item in data]
        processed_texts: Dict[str, matx.NDArray] = self.text_processor(texts)
        processed_images: matx.NDArray = self.vision_processor(images)
        res: Dict[str, matx.NDArray] = {}
        for k in processed_texts:
            res[k] = processed_texts[k]
        res["images"] = processed_images
        return res
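
Before plugging the pipeline into a DataLoader, it can be invoked directly on a hand-built batch to inspect the collated outputs. A minimal sketch, assuming the same placeholder vocab.txt and demo.jpeg files as in the DataLoader example below:

# Hypothetical direct invocation; "vocab.txt" and "demo.jpeg" are placeholders.
with open("demo.jpeg", "rb") as f:
    img_bytes = f.read()

pipeline = MultiModalPipeline("vocab.txt", lower_case=True,
                              max_tokens_per_input=32, is_train=False)
batch = pipeline([
    {"text": b"a cat sitting on a mat", "image": img_bytes},
    {"text": b"this is a demo", "image": img_bytes},
])
for k in batch.keys():
    print(k, batch[k].torch().shape)  # input_ids, input_mask, segment_ids, images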

4. PyTorch DataLoader Example

With the preprocessing logic in place, it can be hooked into a DataLoader for model training. This part is no different from what you are used to writing; below, a PyTorch DataLoader is built with fake data for reference.

from torch.utils.data import DataLoader, BatchSampler, SequentialSampler

class DemoDataset:
    def __init__(self, is_train=True):
        # If you want to run the code, please download the demo image and
        # vocabulary file from GitHub, or replace them with your own files.
        with open("demo.jpeg", "rb") as f:
            img = f.read()
        text = b"this is a demo"
        self.data = {"text": text, "image": img}
        self.transform = MultiModalPipeline("vocab.txt", is_train=is_train)

    def __len__(self):
        return 100  # some fake number

    def __getitem__(self, indices):
        batch_data = [self.data] * len(indices)
        transformed_data = self.transform(batch_data)
        res = {}
        # convert each matx.NDArray to torch tensor
        for k in transformed_data.keys():
            res[k] = transformed_data[k].torch()
        return res


if __name__ == "__main__":
    dataset = DemoDataset()
    # __getitem__ expects a list of indices, so let a BatchSampler yield index
    # lists and disable the DataLoader's own auto-batching with batch_size=None.
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
    loader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)
    for data in loader:
        print(data["images"].shape)
        print(data["input_ids"].shape)