Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Refactor TPS #1240

Merged
merged 6 commits on
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mmocr/models/textrecog/preprocessors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .tps_preprocessor import STN, TPStransform

__all__ = ['TPStransform', 'STN']
12 changes: 12 additions & 0 deletions mmocr/models/textrecog/preprocessors/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.model import BaseModule

from mmocr.registry import MODELS


@MODELS.register_module()
class BasePreprocessor(BaseModule):
    """Base class for text recognition preprocessors.

    Subclasses are expected to override :meth:`forward`; the base
    implementation is an identity mapping.
    """

    def forward(self, x, **kwargs):
        """Return the input unchanged (identity preprocessing)."""
        return x
272 changes: 272 additions & 0 deletions mmocr/models/textrecog/preprocessors/tps_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmocr.registry import MODELS
from .base import BasePreprocessor


class TPStransform(nn.Module):
    """Implement the Thin-Plate-Spline (TPS) transform.

    This was partially adapted from https://github.com/ayumiymk/aster.pytorch

    Args:
        output_image_size (tuple[int, int]): The (height, width) of the output
            image. Defaults to (32, 100).
        num_control_points (int): The number of control points. Should be
            even, since half are placed on the top edge and half on the
            bottom edge of the image. Defaults to 20.
        margins (tuple[float, float]): The margins for control points to the
            top and down side of the image. Defaults to (0.05, 0.05).
    """

    def __init__(self,
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05)) -> None:
        super().__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins
        self.target_height, self.target_width = output_image_size

        # Build the fixed output (target) control points along the top and
        # bottom edges of the rectified image.
        target_control_points = self._build_output_control_points(
            num_control_points, margins)
        N = num_control_points

        # Create the padded TPS kernel matrix of shape (N + 3, N + 3).
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = self._compute_partial_repr(
            target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))

        # Invert once at construction time; the inverse is reused for every
        # forward pass.
        inverse_kernel = torch.inverse(forward_kernel)

        # Create the target coordinate matrix: each output pixel location,
        # normalized to [0, 1] in both axes.
        HW = self.target_height * self.target_width
        tgt_coord = list(
            itertools.product(
                range(self.target_height), range(self.target_width)))
        tgt_coord = torch.Tensor(tgt_coord)
        Y, X = tgt_coord.split(1, dim=1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        tgt_coord = torch.cat([X, Y], dim=1)
        tgt_coord_partial_repr = self._compute_partial_repr(
            tgt_coord, target_control_points)
        tgt_coord_repr = torch.cat(
            [tgt_coord_partial_repr,
             torch.ones(HW, 1), tgt_coord], dim=1)

        # Register precomputed matrices as buffers so they move with the
        # module across devices and are stored in the state dict.
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', tgt_coord_repr)
        self.register_buffer('target_control_points', target_control_points)

    def forward(self, input: torch.Tensor,
                source_control_points: torch.Tensor) -> torch.Tensor:
        """Forward function of the TPS block.

        Args:
            input (Tensor): The input image of shape :math:`(N, C, H, W)`.
            source_control_points (Tensor): The control points of the source
                image of shape (N, self.num_control_points, 2).
        Returns:
            Tensor: The output image after TPS transform, of shape
            :math:`(N, C, target_height, target_width)`.
        """
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_control_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        # Solve for the TPS mapping coefficients of each batch element.
        Y = torch.cat([
            source_control_points,
            self.padding_matrix.expand(batch_size, 3, 2)
        ], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, Y)
        source_coordinate = torch.matmul(self.target_coordinate_repr,
                                         mapping_matrix)

        # Map normalized [0, 1] coordinates to grid_sample's [-1, 1] range.
        grid = source_coordinate.view(-1, self.target_height,
                                      self.target_width, 2)
        grid = torch.clamp(grid, 0, 1)
        grid = 2.0 * grid - 1.0
        output_maps = self._grid_sample(input, grid, canvas=None)
        return output_maps

    def _grid_sample(self,
                     input: torch.Tensor,
                     grid: torch.Tensor,
                     canvas: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample the input image at the given grid.

        Args:
            input (Tensor): The input image.
            grid (Tensor): The grid to sample the input image.
            canvas (Optional[Tensor]): The canvas to store the output image.
                Pixels falling outside the sampled region are filled from the
                canvas when it is given.
        Returns:
            Tensor: The sampled image.
        """
        output = F.grid_sample(input, grid, align_corners=True)
        if canvas is None:
            return output
        else:
            # Blend: keep sampled pixels where the mask is 1, canvas pixels
            # elsewhere.
            input_mask = input.data.new(input.size()).fill_(1)
            output_mask = F.grid_sample(input_mask, grid, align_corners=True)
            padded_output = output * output_mask + canvas * (1 - output_mask)
            return padded_output

    def _compute_partial_repr(self, input_points: torch.Tensor,
                              control_points: torch.Tensor) -> torch.Tensor:
        """Compute the TPS radial-basis representation matrix.

        Each entry is :math:`r^2 \\log(r^2) / 2` for the pairwise squared
        distance :math:`r^2` between an input point and a control point.

        Args:
            input_points (Tensor): The input points of shape (N, 2).
            control_points (Tensor): The control points of shape (M, 2).
        Returns:
            Tensor: The partial representation matrix of shape (N, M).
        """
        N = input_points.size(0)
        M = control_points.size(0)
        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(
            1, M, 2)
        pairwise_diff_square = pairwise_diff * pairwise_diff
        pairwise_dist = pairwise_diff_square[:, :,
                                             0] + pairwise_diff_square[:, :, 1]
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # log(0) produces NaN for coincident points; the TPS kernel is 0
        # there, so replace NaNs (x != x only for NaN) with 0.
        mask = repr_matrix != repr_matrix
        repr_matrix.masked_fill_(mask, 0)
        return repr_matrix

    # output_ctrl_pts are specified, according to our task.
    def _build_output_control_points(self, num_control_points: int,
                                     margins: Tuple[float,
                                                    float]) -> torch.Tensor:
        """Build the output control points.

        The output points will be fix at
        top and down side of the image.
        Args:
            num_control_points (int): The number of control points.
            margins (Tuple[float, float]): The margins for control points to
                the top and down side of the image.
        Returns:
            Tensor: The output control points of shape
            (num_control_points, 2).
        """
        margin_x, margin_y = margins
        num_ctrl_pts_per_side = num_control_points // 2
        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x,
                                 num_ctrl_pts_per_side)
        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
                                             axis=0)
        output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
        return output_ctrl_pts


@MODELS.register_module()
class STN(BasePreprocessor):
    """Implement STN module in ASTER: An Attentional Scene Text Recognizer with
    Flexible Rectification
    (https://ieeexplore.ieee.org/abstract/document/8395027/)

    Args:
        in_channels (int): The number of input channels.
        resized_image_size (Tuple[int, int]): The resized image size. The
            input image will be downsampled to this size to get a better
            rectified result.
        output_image_size: The size of the output image for TPS. Defaults to
            (32, 100).
        num_control_points: The number of control points. Defaults to 20.
        margins: The margins for control points to the top and down side of
            the image for TPS. Defaults to [0.05, 0.05].
    """

    def __init__(self,
                 in_channels: int,
                 resized_image_size: Tuple[int, int] = (32, 64),
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = [0.05, 0.05],
                 init_cfg: Optional[Union[Dict, List[Dict]]] = [
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super().__init__(init_cfg=init_cfg)
        self.resized_image_size = resized_image_size
        self.num_control_points = num_control_points
        self.tps = TPStransform(output_image_size, num_control_points, margins)

        # Localization network: 6 conv stages; every stage except the last is
        # followed by 2x2 max-pooling.
        out_channels = (32, 64, 128, 256, 256, 256)
        layers = []
        prev_channels = in_channels
        for idx, channels in enumerate(out_channels):
            layers.append(
                ConvModule(
                    prev_channels, channels, 3, 1, 1,
                    norm_cfg=dict(type='BN')))
            if idx != len(out_channels) - 1:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            prev_channels = channels
        self.stn_convnet = nn.Sequential(*layers)

        # Regression head predicting 2D offsets for each control point.
        self.stn_fc1 = nn.Sequential(
            nn.Linear(2 * 256, 512), nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_control_points * 2)
        self.init_stn(self.stn_fc2)

    def init_stn(self, stn_fc2: nn.Linear) -> None:
        """Initialize the output linear layer of stn, so that the initial
        source points sit on the top and bottom edges of the image, which
        helps optimization.

        Args:
            stn_fc2 (nn.Linear): The output linear layer of stn.
        """
        margin = 0.01
        pts_per_side = int(self.num_control_points / 2)
        xs = np.linspace(margin, 1. - margin, pts_per_side)
        top_row = np.stack([xs, np.ones(pts_per_side) * margin], axis=1)
        bottom_row = np.stack([xs, np.ones(pts_per_side) * (1 - margin)],
                              axis=1)
        initial_pts = np.concatenate([top_row, bottom_row],
                                     axis=0).astype(np.float32)
        # Zero weights + fixed bias => the very first forward pass emits
        # exactly these edge points.
        stn_fc2.weight.data.zero_()
        stn_fc2.bias.data = torch.Tensor(initial_pts).view(-1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward function of STN.

        Args:
            img (Tensor): The input image tensor.

        Returns:
            Tensor: The rectified image tensor.
        """
        # Localize on a downsampled copy; rectify the full-resolution image.
        downsampled = F.interpolate(
            img, self.resized_image_size, mode='bilinear', align_corners=True)
        feat = self.stn_convnet(downsampled)
        feat = feat.view(feat.size(0), -1)
        hidden = self.stn_fc1(feat)
        # The 0.1 factor damps the fc2 input, keeping predicted points close
        # to their initialization early in training.
        ctrl_points = self.stn_fc2(0.1 * hidden)
        ctrl_points = ctrl_points.view(-1, self.num_control_points, 2)
        return self.tps(img, ctrl_points)
26 changes: 26 additions & 0 deletions tests/models/textrecog/test_preprocessors/test_tps_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmocr.models.textrecog.preprocessors import STN, TPStransform


class TestTPS(TestCase):
    """Smoke tests for the TPS rectification modules."""

    def test_tps_transform(self):
        transform = TPStransform(
            output_image_size=(32, 100), num_control_points=20)
        imgs = torch.rand(2, 3, 32, 64)
        ctrl_pts = torch.rand(2, 20, 2)
        rectified = transform(imgs, ctrl_pts)
        self.assertEqual(rectified.shape, (2, 3, 32, 100))

    def test_stn(self):
        model = STN(
            in_channels=3,
            resized_image_size=(32, 64),
            output_image_size=(32, 100),
            num_control_points=20)
        imgs = torch.rand(2, 3, 64, 256)
        rectified = model(imgs)
        self.assertEqual(rectified.shape, (2, 3, 32, 100))