# Source code for matx.vision.tv_transforms

# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, List, Dict

from .flip import RandomHorizontalFlip, RandomHorizontalFlipImpl, RandomVerticalFlip, RandomVerticalFlipImpl
from .blur import GaussianBlur, GaussianBlurImpl
from .color_jitter import ColorJitter, ColorJitterImpl
from .contrast import RandomAutocontrast, RandomAutocontrastImpl
from .convert_dtype import ConvertImageDtype, ConvertImageDtypeImpl
from .decode import Decode, DecodeImpl
from .equalize import RandomEqualize, RandomEqualizeImpl
from .invert import RandomInvert, RandomInvertImpl
from .normalize import Normalize, NormalizeImpl
from .posterize import RandomPosterize, RandomPosterizeImpl
from .sharp import RandomAdjustSharpness, RandomAdjustSharpnessImpl
from .solarize import RandomSolarize, RandomSolarizeImpl
from .stack import Stack, StackImpl
from .transpose import Transpose, TransposeImpl
from .to_tensor import ToTensor, ToTensorImpl
from .crop import CenterCrop, CenterCropImpl, RandomCrop, RandomCropImpl
from .grayscale import RandomGrayscale, Grayscale
from .pad import Pad, PadImpl
from .crop import CenterCrop, CenterCropImpl, RandomCrop, RandomCropImpl
from .cvt_color import CvtColor, CvtColorImpl
from .resize import Resize, ResizeImpl, RandomResizedCrop, RandomResizedCropImpl
from .warp import RandomRotation, RandomRotationImpl, RandomAffine, RandomAffineImpl, RandomPerspective, RandomPerspectiveImpl

from .. import ASYNC, SYNC
import torch
from ._common import DeviceManager

# Public API of this package (what `from matx.vision.tv_transforms import *`
# exports). Each name is listed exactly once, in a stable order.
__all__ = [
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "GaussianBlur",
    "ColorJitter",
    "RandomAutocontrast",
    "ConvertImageDtype",
    "Decode",
    "RandomEqualize",
    "RandomInvert",
    "Normalize",
    "RandomPosterize",
    "RandomAdjustSharpness",
    "RandomSolarize",
    "Stack",
    "Transpose",
    "ToTensor",
    "CenterCrop",
    "RandomCrop",
    "RandomGrayscale",
    "Grayscale",
    "Pad",
    "CvtColor",
    "Resize",
    "RandomResizedCrop",
    "RandomRotation",
    "RandomAffine",
    "RandomPerspective",
    "DeviceManager",
    "Compose",
    "set_device",
]


class Compose(object):
    """Chain several tv_transforms op descriptors into one callable pipeline.

    Each element of ``transforms`` must expose ``device_id()``, ``has_sync()``
    and ``sync()``. It must also be callable as ``op(device, device_str, sync)``;
    that call returns the executable op stored in ``self.transforms``.
    ``Compose`` resolves those three arguments for every descriptor.
    The executable ops are then applied in the original list order.
    """

    def __init__(self, device_id: int, transforms: List[Any]) -> None:
        """Instantiate the executable op for every descriptor in ``transforms``.

        Args:
            device_id: Default device, used for any op whose ``device_id()``
                returns -2. ``>= 0`` maps to ``"gpu:<id>"``, anything else
                to ``"cpu"`` (see ``_create_device_str``).
            transforms: Op descriptors, in the order they should be applied.
        """
        self.default_device_id: int = device_id
        # Executable ops, in application order (filled back-to-front below).
        self.transforms: List[Any] = []
        # Memo of device_id -> device string, filled by _create_device_str.
        self.device_str: Dict[int, str] = {}
        # Flips to True after the first (i.e. last-in-pipeline) op whose
        # has_sync() is True has been assigned SYNC.
        set_last_op_sync: bool = False
        op_len: int = len(transforms)
        # Visit descriptors from last to first so the final op is seen first.
        for i in range(op_len):
            op: Any = transforms[op_len - i - 1]
            op_device_id: int = op.device_id()
            if op_device_id == -2:
                # -2 is a sentinel: use the Compose-level default device.
                op_device_id = self.default_device_id
            op_device: Any = DeviceManager(op_device_id)
            op_device_str: str = self._create_device_str(op_device_id)
            op_has_sync: bool = op.has_sync()
            op_sync: int = ASYNC
            # NOTE(review): the scraped source lost its indentation, so this
            # `else` is reconstructed as pairing with the inner
            # `if not set_last_op_sync` (nearest-if). Confirm against upstream
            # matx. With this reading:
            #   - the last op with has_sync() True gets SYNC;
            #   - earlier such ops use their own sync();
            #   - ops with has_sync() False get ASYNC.
            if op_has_sync:
                if not set_last_op_sync:
                    op_sync = SYNC
                    set_last_op_sync = True
                else:
                    op_sync = op.sync()
            # Prepend so that self.transforms ends up in the original order.
            self.transforms = [op(op_device, op_device_str, op_sync)] + self.transforms

    def _create_device_str(self, device_id: int) -> str:
        """Return ``"gpu:<device_id>"`` if ``device_id >= 0``, else ``"cpu"``.

        Results are memoised in ``self.device_str``.
        """
        if device_id in self.device_str:
            return self.device_str[device_id]
        cur_device_str: str = "cpu"
        if device_id >= 0:
            cur_device_str = "gpu:{}".format(device_id)
        self.device_str[device_id] = cur_device_str
        return cur_device_str

    def __call__(self, imgs: Any) -> Any:
        """Apply every op in order, feeding each op's output to the next."""
        for t in self.transforms:
            # NOTE(review): the second argument is always an empty list here;
            # its meaning is defined by the op implementations. Confirm.
            imgs = t(imgs, [])
        return imgs

    def __repr__(self) -> str:
        """Return ``Compose(`` followed by one line per op, then ``)``."""
        format_string = "Compose" + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += ' {0}'.format(t)
        format_string += '\n)'
        return format_string
class set_device():
    """Context manager that temporarily selects a CUDA device.

    On entry, the current CUDA device is remembered (when CUDA is available).
    If ``device_id >= 0``, ``torch.cuda.set_device(device_id)`` is called.
    On exit, the remembered device is restored. A negative ``device_id``
    leaves the CUDA device selection untouched.

    Example::

        with set_device(1):
            ...  # CUDA work here targets device 1
    """

    def __init__(self, device_id):
        """
        Args:
            device_id: CUDA device index to select; negative values select none.
        """
        self.device_id = device_id
        # Device that was current at __enter__; None if CUDA was not queried.
        self.ori_device_id = None

    def __enter__(self):
        # Only query CUDA when it exists, so a context with a negative
        # device_id also works on CUDA-less machines/builds.
        if torch.cuda.is_available():
            self.ori_device_id = torch.cuda.current_device()
        # Device 0 is a valid CUDA device and must be selectable too.
        if self.device_id >= 0:
            torch.cuda.set_device(self.device_id)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Compare against None: device 0 is falsy but must still be restored.
        if self.ori_device_id is not None:
            torch.cuda.set_device(self.ori_device_id)
        # Implicitly returns None, so exceptions from the body propagate.