[Feature] Add svtr decoder (#1448)
* add svtr decoder

* svtr decoder

* update

Co-authored-by: gaotongxiao <[email protected]>
willpat1213 and gaotongxiao authored Dec 30, 2022
1 parent 53e72e4 commit 7e9f775
Showing 3 changed files with 192 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmocr/models/textrecog/decoders/__init__.py
@@ -12,11 +12,12 @@
from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder
from .sar_decoder_with_bs import ParallelSARDecoderWithBS
from .sequence_attention_decoder import SequenceAttentionDecoder
from .svtr_decoder import SVTRDecoder

__all__ = [
'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
'ParallelSARDecoderWithBS', 'NRTRDecoder', 'BaseDecoder',
'SequenceAttentionDecoder', 'PositionAttentionDecoder',
'ABILanguageDecoder', 'ABIVisionDecoder', 'MasterDecoder',
-    'RobustScannerFuser', 'ABIFuser', 'ASTERDecoder'
+    'RobustScannerFuser', 'ABIFuser', 'SVTRDecoder', 'ASTERDecoder'
]
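
With the new entry in __all__, the decoder becomes importable at the package level, which is what the test file below relies on:

from mmocr.models.textrecog.decoders import SVTRDecoder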
96 changes: 96 additions & 0 deletions mmocr/models/textrecog/decoders/svtr_decoder.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn

from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS
from mmocr.structures import TextRecogDataSample
from .base import BaseDecoder


@MODELS.register_module()
class SVTRDecoder(BaseDecoder):
"""Decoder module in `SVTR <https://arxiv.org/abs/2205.00159>`_.
Args:
in_channels (int): The num of input channels.
dictionary (Union[Dict, Dictionary]): The config for `Dictionary` or
the instance of `Dictionary`. Defaults to None.
module_loss (Optional[Dict], optional): Cfg to build module_loss.
Defaults to None.
postprocessor (Optional[Dict], optional): Cfg to build postprocessor.
Defaults to None.
max_seq_len (int, optional): Maximum output sequence length :math:`T`.
Defaults to 25.
init_cfg (dict or list[dict], optional): Initialization configs.
Defaults to None.
"""

def __init__(self,
in_channels: int,
dictionary: Union[Dict, Dictionary] = None,
module_loss: Optional[Dict] = None,
postprocessor: Optional[Dict] = None,
max_seq_len: int = 25,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:

super().__init__(
dictionary=dictionary,
module_loss=module_loss,
postprocessor=postprocessor,
max_seq_len=max_seq_len,
init_cfg=init_cfg)

        # Per-step linear classifier over the dictionary; the softmax is only
        # applied at test time to turn logits into probabilities.
        self.decoder = nn.Linear(
            in_features=in_channels, out_features=self.dictionary.num_classes)
        self.softmax = nn.Softmax(dim=-1)

def forward_train(
self,
feat: Optional[torch.Tensor] = None,
out_enc: Optional[torch.Tensor] = None,
data_samples: Optional[Sequence[TextRecogDataSample]] = None
) -> torch.Tensor:
"""Forward for training.
Args:
feat (torch.Tensor, optional): The feature map from backbone of
shape :math:`(N, E, H, W)`. Defaults to None.
out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
data_samples (Sequence[TextRecogDataSample]): Batch of
TextRecogDataSample, containing gt_text information. Defaults
to None.
Returns:
Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
:math:`C` is ``num_classes``.
"""
        assert feat.size(2) == 1, 'feature height must be 1'
        x = feat.squeeze(2)  # (N, E, H=1, W) -> (N, E, W)
        x = x.permute(0, 2, 1)  # (N, E, W) -> (N, W, E): width becomes time
        predicts = self.decoder(x)  # (N, W, E) -> (N, W, C)
        return predicts

def forward_test(
self,
feat: Optional[torch.Tensor] = None,
out_enc: Optional[torch.Tensor] = None,
data_samples: Optional[Sequence[TextRecogDataSample]] = None
) -> torch.Tensor:
"""Forward for testing.
Args:
feat (torch.Tensor, optional): The feature map from backbone of
shape :math:`(N, E, H, W)`. Defaults to None.
out_enc (torch.Tensor, optional): Encoder output. Defaults to None.
data_samples (Sequence[TextRecogDataSample]): Batch of
TextRecogDataSample, containing gt_text information. Defaults
to None.
Returns:
Tensor: Character probabilities. of shape
:math:`(N, self.max_seq_len, C)` where :math:`C` is
``num_classes``.
"""
return self.softmax(self.forward_train(feat, out_enc, data_samples))
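
For orientation, a minimal usage sketch of the new decoder, mirroring the configs in the tests below. It assumes the dummy-dictionary helper from mmocr.testing (36 characters, which yields 39 classes once the shared start/end, padding and unknown tokens are enabled):

import os.path as osp
import tempfile

import torch

from mmocr.models.textrecog.decoders import SVTRDecoder
from mmocr.testing import create_dummy_dict_file

with tempfile.TemporaryDirectory() as tmp_dir:
    dict_file = osp.join(tmp_dir, 'fake_chars.txt')
    create_dummy_dict_file(dict_file)  # writes a small dummy character list
    decoder = SVTRDecoder(
        in_channels=192,
        dictionary=dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True))
    # Backbone feature with height 1, as the assert in forward_train requires.
    feat = torch.randn(1, 192, 1, 25)  # (N, E, H, W)
    probs = decoder.forward_test(feat=feat)
    print(probs.shape)  # torch.Size([1, 25, 39])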
94 changes: 94 additions & 0 deletions tests/test_models/test_textrecog/test_decoders/test_svtr_decoder.py
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase

import torch
from mmengine.structures import LabelData

from mmocr.models.textrecog.decoders.svtr_decoder import SVTRDecoder
from mmocr.structures import TextRecogDataSample
from mmocr.testing import create_dummy_dict_file


class TestSVTRDecoder(TestCase):

def setUp(self):
gt_text_sample1 = TextRecogDataSample()
gt_text = LabelData()
gt_text.item = 'Hello'
gt_text_sample1.gt_text = gt_text
gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))

        gt_text_sample2 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'World'
        gt_text_sample2.gt_text = gt_text
        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))

self.data_info = [gt_text_sample1, gt_text_sample2]

def test_init(self):
with tempfile.TemporaryDirectory() as tmp_dir:
dict_file = osp.join(tmp_dir, 'fake_chars.txt')
create_dummy_dict_file(dict_file)
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=True,
with_padding=True,
with_unknown=True)
loss_cfg = dict(type='CTCModuleLoss', letter_case='lower')
SVTRDecoder(
in_channels=192, dictionary=dict_cfg, module_loss=loss_cfg)

def test_forward_train(self):
feat = torch.randn(1, 192, 1, 25)
tmp_dir = tempfile.TemporaryDirectory()
max_seq_len = 25
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
create_dummy_dict_file(dict_file)
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=True,
with_padding=True,
with_unknown=True)
loss_cfg = dict(type='CTCModuleLoss', letter_case='lower')
decoder = SVTRDecoder(
in_channels=192,
dictionary=dict_cfg,
module_loss=loss_cfg,
max_seq_len=max_seq_len,
)
data_samples = decoder.module_loss.get_targets(self.data_info)
output = decoder.forward_train(feat=feat, data_samples=data_samples)
self.assertTupleEqual(tuple(output.shape), (1, max_seq_len, 39))

def test_forward_test(self):
feat = torch.randn(1, 192, 1, 25)
tmp_dir = tempfile.TemporaryDirectory()
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
create_dummy_dict_file(dict_file)
        # test dictionary cfg
dict_cfg = dict(
type='Dictionary',
dict_file=dict_file,
with_start=True,
with_end=True,
same_start_end=True,
with_padding=True,
with_unknown=True)
loss_cfg = dict(type='CTCModuleLoss', letter_case='lower')
decoder = SVTRDecoder(
in_channels=192,
dictionary=dict_cfg,
module_loss=loss_cfg,
max_seq_len=25)
output = decoder.forward_test(feat=feat, data_samples=self.data_info)
self.assertTupleEqual(tuple(output.shape), (1, 25, 39))
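
A quick sanity check on the asserted shapes: create_dummy_dict_file writes 36 characters by default (digits plus lowercase letters), and with with_start/with_end (shared via same_start_end), with_padding and with_unknown enabled, the dictionary gains three special tokens, so num_classes = 36 + 3 = 39. The sequence length of 25 comes straight from the width of the input feature map.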
