-
Notifications
You must be signed in to change notification settings - Fork 754
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add svtr decoder * svtr decoder * update Co-authored-by: gaotongxiao <[email protected]>
- Loading branch information
1 parent
53e72e4
commit 7e9f775
Showing
3 changed files
with
192 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, List, Optional, Sequence, Union | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmocr.models.common.dictionary import Dictionary | ||
from mmocr.registry import MODELS | ||
from mmocr.structures import TextRecogDataSample | ||
from .base import BaseDecoder | ||
|
||
|
||
@MODELS.register_module() | ||
class SVTRDecoder(BaseDecoder): | ||
"""Decoder module in `SVTR <https://arxiv.org/abs/2205.00159>`_. | ||
Args: | ||
in_channels (int): The num of input channels. | ||
dictionary (Union[Dict, Dictionary]): The config for `Dictionary` or | ||
the instance of `Dictionary`. Defaults to None. | ||
module_loss (Optional[Dict], optional): Cfg to build module_loss. | ||
Defaults to None. | ||
postprocessor (Optional[Dict], optional): Cfg to build postprocessor. | ||
Defaults to None. | ||
max_seq_len (int, optional): Maximum output sequence length :math:`T`. | ||
Defaults to 25. | ||
init_cfg (dict or list[dict], optional): Initialization configs. | ||
Defaults to None. | ||
""" | ||
|
||
def __init__(self, | ||
in_channels: int, | ||
dictionary: Union[Dict, Dictionary] = None, | ||
module_loss: Optional[Dict] = None, | ||
postprocessor: Optional[Dict] = None, | ||
max_seq_len: int = 25, | ||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: | ||
|
||
super().__init__( | ||
dictionary=dictionary, | ||
module_loss=module_loss, | ||
postprocessor=postprocessor, | ||
max_seq_len=max_seq_len, | ||
init_cfg=init_cfg) | ||
|
||
self.decoder = nn.Linear( | ||
in_features=in_channels, out_features=self.dictionary.num_classes) | ||
self.softmax = nn.Softmax(dim=-1) | ||
|
||
def forward_train( | ||
self, | ||
feat: Optional[torch.Tensor] = None, | ||
out_enc: Optional[torch.Tensor] = None, | ||
data_samples: Optional[Sequence[TextRecogDataSample]] = None | ||
) -> torch.Tensor: | ||
"""Forward for training. | ||
Args: | ||
feat (torch.Tensor, optional): The feature map from backbone of | ||
shape :math:`(N, E, H, W)`. Defaults to None. | ||
out_enc (torch.Tensor, optional): Encoder output. Defaults to None. | ||
data_samples (Sequence[TextRecogDataSample]): Batch of | ||
TextRecogDataSample, containing gt_text information. Defaults | ||
to None. | ||
Returns: | ||
Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where | ||
:math:`C` is ``num_classes``. | ||
""" | ||
assert feat.size(2) == 1, 'feature height must be 1' | ||
x = feat.squeeze(2) | ||
x = x.permute(0, 2, 1) | ||
predicts = self.decoder(x) | ||
return predicts | ||
|
||
def forward_test( | ||
self, | ||
feat: Optional[torch.Tensor] = None, | ||
out_enc: Optional[torch.Tensor] = None, | ||
data_samples: Optional[Sequence[TextRecogDataSample]] = None | ||
) -> torch.Tensor: | ||
"""Forward for testing. | ||
Args: | ||
feat (torch.Tensor, optional): The feature map from backbone of | ||
shape :math:`(N, E, H, W)`. Defaults to None. | ||
out_enc (torch.Tensor, optional): Encoder output. Defaults to None. | ||
data_samples (Sequence[TextRecogDataSample]): Batch of | ||
TextRecogDataSample, containing gt_text information. Defaults | ||
to None. | ||
Returns: | ||
Tensor: Character probabilities. of shape | ||
:math:`(N, self.max_seq_len, C)` where :math:`C` is | ||
``num_classes``. | ||
""" | ||
return self.softmax(self.forward_train(feat, out_enc, data_samples)) |
94 changes: 94 additions & 0 deletions
94
tests/test_models/test_textrecog/test_decoders/test_svtr_decoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os.path as osp | ||
import tempfile | ||
from unittest import TestCase | ||
|
||
import torch | ||
from mmengine.structures import LabelData | ||
|
||
from mmocr.models.textrecog.decoders.svtr_decoder import SVTRDecoder | ||
from mmocr.structures import TextRecogDataSample | ||
from mmocr.testing import create_dummy_dict_file | ||
|
||
|
||
class TestSVTRDecoder(TestCase): | ||
|
||
def setUp(self): | ||
gt_text_sample1 = TextRecogDataSample() | ||
gt_text = LabelData() | ||
gt_text.item = 'Hello' | ||
gt_text_sample1.gt_text = gt_text | ||
gt_text_sample1.set_metainfo(dict(valid_ratio=0.9)) | ||
|
||
gt_text_sample2 = TextRecogDataSample() | ||
gt_text = LabelData() | ||
gt_text = LabelData() | ||
gt_text.item = 'World' | ||
gt_text_sample2.gt_text = gt_text | ||
gt_text_sample2.set_metainfo(dict(valid_ratio=1.0)) | ||
|
||
self.data_info = [gt_text_sample1, gt_text_sample2] | ||
|
||
def test_init(self): | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
dict_file = osp.join(tmp_dir, 'fake_chars.txt') | ||
create_dummy_dict_file(dict_file) | ||
dict_cfg = dict( | ||
type='Dictionary', | ||
dict_file=dict_file, | ||
with_start=True, | ||
with_end=True, | ||
same_start_end=True, | ||
with_padding=True, | ||
with_unknown=True) | ||
loss_cfg = dict(type='CTCModuleLoss', letter_case='lower') | ||
SVTRDecoder( | ||
in_channels=192, dictionary=dict_cfg, module_loss=loss_cfg) | ||
|
||
def test_forward_train(self): | ||
feat = torch.randn(1, 192, 1, 25) | ||
tmp_dir = tempfile.TemporaryDirectory() | ||
max_seq_len = 25 | ||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt') | ||
create_dummy_dict_file(dict_file) | ||
dict_cfg = dict( | ||
type='Dictionary', | ||
dict_file=dict_file, | ||
with_start=True, | ||
with_end=True, | ||
same_start_end=True, | ||
with_padding=True, | ||
with_unknown=True) | ||
loss_cfg = dict(type='CTCModuleLoss', letter_case='lower') | ||
decoder = SVTRDecoder( | ||
in_channels=192, | ||
dictionary=dict_cfg, | ||
module_loss=loss_cfg, | ||
max_seq_len=max_seq_len, | ||
) | ||
data_samples = decoder.module_loss.get_targets(self.data_info) | ||
output = decoder.forward_train(feat=feat, data_samples=data_samples) | ||
self.assertTupleEqual(tuple(output.shape), (1, max_seq_len, 39)) | ||
|
||
def test_forward_test(self): | ||
feat = torch.randn(1, 192, 1, 25) | ||
tmp_dir = tempfile.TemporaryDirectory() | ||
dict_file = osp.join(tmp_dir.name, 'fake_chars.txt') | ||
create_dummy_dict_file(dict_file) | ||
# test diction cfg | ||
dict_cfg = dict( | ||
type='Dictionary', | ||
dict_file=dict_file, | ||
with_start=True, | ||
with_end=True, | ||
same_start_end=True, | ||
with_padding=True, | ||
with_unknown=True) | ||
loss_cfg = dict(type='CTCModuleLoss', letter_case='lower') | ||
decoder = SVTRDecoder( | ||
in_channels=192, | ||
dictionary=dict_cfg, | ||
module_loss=loss_cfg, | ||
max_seq_len=25) | ||
output = decoder.forward_test(feat=feat, data_samples=self.data_info) | ||
self.assertTupleEqual(tuple(output.shape), (1, 25, 39)) |