Application Example
This article introduces how Matx is applied to multimodal tasks. Taking image/text multimodality as an example, it focuses on the implementation of the preprocessing logic; the model itself is out of scope. Source code
1. Text Modality Preprocessing
The text-modality preprocessing below mainly implements the logic of BertTokenizer.
1.1 Helper Functions
We first define some BertTokenizer-style text cleaning and pre-transform logic (OPs).
Text cleaning
import matx

class TextCleaner:
    """TextCleaner impl by matx."""

    def __init__(self) -> None:
        self.white_regex: matx.Regex = matx.Regex(r"[ \t\n\r\p{Zs}]")
        self.control_regex: matx.Regex = matx.Regex(
            r"[\u0000\ufffd\p{Cc}\p{Cf}\p{Mn}]")
        self.space: bytes = " ".encode()
        self.empty: bytes = "".encode()

    def __call__(self, text: bytes) -> bytes:
        # Normalize every whitespace variant to a single space, then
        # drop control, format, and non-spacing-mark characters.
        t = self.white_regex.replace(text, self.space)
        return self.control_regex.replace(t, self.empty)
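The cleaner can be scripted and called directly as a sanity check. The following is a minimal sketch; the sample input and the expected output in the comment are illustrative assumptions, not from the original article.

cleaner = matx.script(TextCleaner)()
# "\t" is normalized to a space and the zero-width space (a Cf
# character) is removed; expected output: b'hello world!'
print(cleaner("hello\tworld\u200b!".encode()))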
Case normalization
class CaseNormalizer:
    """Optionally convert text to lowercase."""

    def __init__(self, do_lowercase: bool = False, unicode_norm: str = '') -> None:
        # unicode_norm is reserved for Unicode normalization and is
        # currently unused.
        self.do_lowercase: bool = do_lowercase

    def __call__(self, text: bytes) -> bytes:
        if self.do_lowercase:
            return text.lower()
        else:
            return text
Punctuation handling
class PunctuationPadding:
    """Pad a space around each punctuation character."""

    def __init__(self) -> None:
        self.regex_pattern: matx.Regex = matx.Regex(
            r"([\u0021-\u002f]|[\u003a-\u0040]|[\u005b-\u0060]|[\u007b-\u007e]|\p{P})")
        self.replace_pattern: bytes = r" ${1} ".encode()

    def __call__(self, text: bytes) -> bytes:
        return self.regex_pattern.replace(text, self.replace_pattern)
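These OPs are meant to be chained before tokenization. Below is a small sketch of how they compose; the sample strings and the outputs in the comments are assumptions for illustration.

normalizer = matx.script(CaseNormalizer)(do_lowercase=True)
punc_padding = matx.script(PunctuationPadding)()
text = normalizer(b"Hello,World!")  # b'hello,world!'
# each punctuation character gets a space padded on both sides
print(punc_padding(text))           # b'hello , world ! '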
1.2 Matx-based BertTokenizer
The following implements the BertTokenizer logic based on Matx, using the text cleaning and case normalization utilities implemented above.
from typing import Any, Dict, List
import matx
from matx.text import WordPieceTokenizer

class MatxBertTokenizer:
    def __init__(self,
                 vocab_path: str,
                 lower_case: bool = False,
                 max_tokens_per_input: int = 256,
                 unk_token: str = '[UNK]'
                 ) -> None:
        """Matx-style BertTokenizer.

        vocab_path: vocabulary path for the tokenizer
        lower_case: whether to convert text to lowercase
        max_tokens_per_input: token length limit
        unk_token: the symbol for unknown tokens
        """
        self.cleaner: TextCleaner = TextCleaner()
        self.normalizer: CaseNormalizer = CaseNormalizer(lower_case)
        self.punc_padding: PunctuationPadding = PunctuationPadding()
        self.max_tokens_per_input: int = max_tokens_per_input
        self.word_piece: Any = WordPieceTokenizer(vocab_path=vocab_path,
                                                  unk_token=unk_token,
                                                  max_bytes_per_token=max_tokens_per_input)
        self.cls_id: int = self.word_piece.tokenize(['[CLS]'])[0]
        self.sep_id: int = self.word_piece.tokenize(['[SEP]'])[0]
        self.pad_id: int = self.word_piece.tokenize(['[PAD]'])[0]

    def __call__(self, texts: List[bytes]) -> Dict[str, matx.NDArray]:
        batch_input_ids: List = []
        batch_input_mask: List = []
        batch_segment_ids: List = []
        for text in texts:
            text = self.cleaner(text)
            text = self.normalizer(text)
            text = self.punc_padding(text)
            terms: List = text.split()
            tokens: List[int] = self.word_piece.tokenize(terms)
            # start to create BERT-style input
            len_tre: int = self.max_tokens_per_input - 2  # reserve room for [CLS] and [SEP]
            input_ids: List = [self.cls_id] + tokens[:len_tre] + [self.sep_id]
            input_mask: List = [1] * len(input_ids) + [0] * (self.max_tokens_per_input - len(input_ids))
            input_ids = input_ids + [self.pad_id] * (self.max_tokens_per_input - len(input_ids))
            segment_ids: List = [0] * self.max_tokens_per_input
            batch_input_ids.append(input_ids)
            batch_input_mask.append(input_mask)
            batch_segment_ids.append(segment_ids)
        res: Dict = {}
        res["input_ids"] = matx.NDArray(batch_input_ids, [], "int64")
        res["input_mask"] = matx.NDArray(batch_input_mask, [], "int64")
        res["segment_ids"] = matx.NDArray(batch_segment_ids, [], "int64")
        return res
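A usage sketch for the tokenizer follows. It assumes a BERT-style vocabulary file vocab.txt (containing [CLS], [SEP], [PAD], and [UNK]) in the working directory; the file name is a placeholder.

# hypothetical vocabulary path; replace with your own
tokenizer = matx.script(MatxBertTokenizer)("vocab.txt",
                                           lower_case=True,
                                           max_tokens_per_input=256)
out = tokenizer([b"this is a demo"])
# each output is a [batch_size, max_tokens_per_input] int64 NDArray,
# e.g. torch.Size([1, 256]) after conversion
print(out["input_ids"].torch().shape)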
2. Image Modality Preprocessing
The following takes ResNet preprocessing as an example to implement the image-modality preprocessing logic. The main vision OPs involved are Decode, RandomResizedCrop, CenterCrop, RandomHorizontalFlip, Normalize, etc.
from typing import List, Dict, Any
import matx
from matx.vision.tv_transforms import Decode, RandomHorizontalFlip, \
    RandomResizedCrop, CenterCrop, Normalize, Stack, Transpose, Compose

class MatxImagenetVisionProcessor:
    def __init__(self, device_id: int = -1, is_train: bool = True) -> None:
        # device_id >= 0 runs the vision OPs on the corresponding GPU;
        # -1 runs them on CPU.
        self.is_train: bool = is_train
        vision_ops: List = []
        if is_train:  # image transforms for training
            vision_ops = [
                matx.script(Decode)(to_rgb=True),
                matx.script(RandomResizedCrop)(size=[224, 224],
                                               scale=(0.08, 1.0),
                                               ratio=(0.75, 1.33)),
                matx.script(RandomHorizontalFlip)(),
                matx.script(Normalize)(mean=[123.675, 116.28, 103.53],
                                       std=[58.395, 57.12, 57.375]),
                matx.script(Stack)(),
                matx.script(Transpose)()
            ]
        else:  # image transforms for evaluation
            vision_ops = [
                matx.script(Decode)(to_rgb=True),
                matx.script(CenterCrop)(size=[224, 224]),
                matx.script(Normalize)(mean=[123.675, 116.28, 103.53],
                                       std=[58.395, 57.12, 57.375]),
                matx.script(Stack)(),
                matx.script(Transpose)()
            ]
        self.vision_op: Any = matx.script(Compose)(device_id, vision_ops)

    def __call__(self, images: List[bytes]) -> matx.NDArray:
        return self.vision_op(images)
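The processor can be tried standalone, as sketched below. The image file demo.jpeg is a placeholder; device_id=-1 keeps all OPs on CPU.

# hypothetical local image; replace with your own
with open("demo.jpeg", "rb") as f:
    image_bytes = f.read()

processor = MatxImagenetVisionProcessor(device_id=-1, is_train=False)
batch = processor([image_bytes])
# Stack + Transpose produce an NCHW batch,
# e.g. torch.Size([1, 3, 224, 224]) after conversion
print(batch.torch().shape)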
3. Combining the Image and Text Modalities
The text and image preprocessing above can now be combined into the overall multimodal preprocessing logic.
@matx.script
class MultiModalPipeline:
    def __init__(self,
                 vocab_path: str,
                 lower_case: bool = False,
                 max_tokens_per_input: int = 256,
                 unk_token: str = '[UNK]',
                 vision_device_id: int = -1,
                 is_train: bool = True) -> None:
        self.text_processor: Any = MatxBertTokenizer(
            vocab_path, lower_case, max_tokens_per_input, unk_token
        )
        self.vision_processor: Any = MatxImagenetVisionProcessor(
            vision_device_id, is_train
        )

    # The input is a batch of data, where each item is like
    # {"text": b"some text", "image": b"some image"}.
    # The output is collated; organize the result in any format you want.
    # The code below outputs the processed data like
    # {"images": batched_image, "input_ids": batched_input_ids,
    #  "input_mask": batched_input_mask, "segment_ids": batched_segment_ids}
    def __call__(self, data: List[Dict[str, Any]]) -> Dict[str, matx.NDArray]:
        texts: List[bytes] = [item["text"] for item in data]
        images: List[bytes] = [item["image"] for item in data]
        processed_texts: Dict[str, matx.NDArray] = self.text_processor(texts)
        processed_images: matx.NDArray = self.vision_processor(images)
        res: Dict[str, matx.NDArray] = {}
        for k in processed_texts:
            res[k] = processed_texts[k]
        res["images"] = processed_images
        return res
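A minimal end-to-end sketch, assuming vocab.txt and demo.jpeg are available locally (both names are placeholders):

pipeline = MultiModalPipeline("vocab.txt", lower_case=True, is_train=False)
with open("demo.jpeg", "rb") as f:
    img = f.read()
out = pipeline([{"text": b"a demo sentence", "image": img}])
for name in out.keys():
    # input_ids / input_mask / segment_ids: [1, 256]; images: [1, 3, 224, 224]
    print(name, out[name].torch().shape)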
4. PyTorch DataLoader Example
With the preprocessing logic in place, it can be plugged into a DataLoader for model training. This part is no different from the usual PyTorch code; below, a DataLoader is built on fake data for reference. Note that __getitem__ here consumes a whole batch of indices at a time, so the DataLoader disables automatic batching and delegates batching to a BatchSampler.
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler

class DemoDataset:
    def __init__(self, is_train=True):
        # To run this code, please download the demo image and vocabulary
        # file from github, or replace them with your own.
        with open("demo.jpeg", "rb") as f:
            img = f.read()
        text = b"this is a demo"
        self.data = {"text": text, "image": img}
        self.transform = MultiModalPipeline("vocab.txt", is_train=is_train)

    def __len__(self):
        return 100  # some fake number

    def __getitem__(self, indices):
        # indices is a list of indices (one whole batch), since the
        # preprocessing pipeline consumes a batch at a time
        batch_data = [self.data] * len(indices)
        transformed_data = self.transform(batch_data)
        res = {}
        # convert each matx.NDArray to a torch tensor
        for k in transformed_data.keys():
            res[k] = transformed_data[k].torch()
        return res

if __name__ == "__main__":
    dataset = DemoDataset()
    # __getitem__ expects a list of indices, so batching is done by the
    # sampler and automatic collation is disabled with batch_size=None
    sampler = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
    loader = DataLoader(dataset, batch_size=None, sampler=sampler)
    for data in loader:
        print(data["images"].shape)
        print(data["input_ids"].shape)