question: The original model is:

```python
import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import to_2tuple

from lib.models.layers.patch_embed import PatchEmbed, PatchEmbed_event, xcorr_depthwise
from .utils import combine_tokens, recover_tokens
from .vit import VisionTransformer
from ..layers.attn_blocks import CEBlock

_logger = logging.getLogger(__name__)


class VisionTransformerCE(VisionTransformer):
    """ Vision Transformer with candidate elimination (CE) module

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init='', ce_loc=None, ce_keep_ratio=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        if isinstance(img_size, tuple):
            self.img_size = img_size
        else:
            self.img_size = to_2tuple(img_size)
        self.patch_size = patch_size
        self.in_chans = in_chans

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.pos_embed_event = PatchEmbed_event(in_chans=32, embed_dim=768, kernel_size=4, stride=4)
        # self.pos_embed_event = PatchEmbed_event(in_chans=32, embed_dim=768, kernel_size=4, stride=4)
        # self.pos_embed_event_z = PatchEmbed_event(in_chans=32, embed_dim=768, kernel_size=3, stride=1)
        # attn = CrossAttn(768, 4, 3072, 0.1, 'relu')
        # self.cross_attn = Iter_attn(attn, 2)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        blocks = []
        ce_index = 0
        self.ce_loc = ce_loc
        for i in range(depth):
            ce_keep_ratio_i = 1.0
            if ce_loc is not None and i in ce_loc:
                ce_keep_ratio_i = ce_keep_ratio[ce_index]
                ce_index += 1

            blocks.append(
                CEBlock(
                    dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                    attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                    act_layer=act_layer, keep_ratio_search=ce_keep_ratio_i)
            )

        self.blocks = nn.Sequential(*blocks)
        self.norm = norm_layer(embed_dim)

        self.init_weights(weight_init)

    def forward_features(self, z, x, event_z, event_x,
                         mask_z=None, mask_x=None,
                         ce_template_mask=None, ce_keep_rate=None,
                         return_last_attn=False):
        B, H, W = x.shape[0], x.shape[2], x.shape[3]

        event_z = self.pos_embed_event(event_z)  # [:,:,:,:1000]
        event_x = self.pos_embed_event(event_x)  # B 768 1024
        x = self.patch_embed(x)
        z = self.patch_embed(z)

        event_z += self.pos_embed_z
        event_x += self.pos_embed_x
        z += self.pos_embed_z
        x += self.pos_embed_x

        # attention mask handling
        # B, H, W
        if mask_z is not None and mask_x is not None:
            mask_z = F.interpolate(mask_z[None].float(), scale_factor=1. / self.patch_size).to(torch.bool)[0]
            mask_z = mask_z.flatten(1).unsqueeze(-1)

            mask_x = F.interpolate(mask_x[None].float(), scale_factor=1. / self.patch_size).to(torch.bool)[0]
            mask_x = mask_x.flatten(1).unsqueeze(-1)

            mask_x = combine_tokens(mask_z, mask_x, mode=self.cat_mode)
            mask_x = mask_x.squeeze(-1)

        if self.add_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            cls_tokens = cls_tokens + self.cls_pos_embed

        if self.add_sep_seg:
            x += self.search_segment_pos_embed
            z += self.template_segment_pos_embed

        x = combine_tokens(z, event_z, x, event_x, mode=self.cat_mode)  # 64+64+256+256=640
        # x = combine_tokens(z, x, event_z, event_x, mode=self.cat_mode)  # 64+64+256+256=640
        if self.add_cls_token:
            x = torch.cat([cls_tokens, x], dim=1)

        x = self.pos_drop(x)

        lens_z = self.pos_embed_z.shape[1]
        lens_x = self.pos_embed_x.shape[1]

        global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device)
        global_index_t = global_index_t.repeat(B, 1)

        global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device)
        global_index_s = global_index_s.repeat(B, 1)

        removed_indexes_s = []
        for i, blk in enumerate(self.blocks):
            x, global_index_t, global_index_s, removed_index_s, attn = \
                blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate)

            if self.ce_loc is not None and i in self.ce_loc:
                removed_indexes_s.append(removed_index_s)

        x = self.norm(x)
        lens_x_new = global_index_s.shape[1]
        lens_z_new = global_index_t.shape[1]

        z = x[:, :lens_z_new * 2]
        x = x[:, lens_z_new * 2:]

        if removed_indexes_s and removed_indexes_s[0] is not None:
            removed_indexes_cat = torch.cat(removed_indexes_s, dim=1)

            pruned_lens_x = lens_x - lens_x_new
            pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]], device=x.device)
            x = torch.cat([x, pad_x], dim=1)
            index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1)
            # recover original token order
            C = x.shape[-1]
            x = torch.zeros_like(x).scatter_(dim=1,
                                             index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64),
                                             src=x)

        x = recover_tokens(x, lens_z_new, lens_x, mode=self.cat_mode)
        x = x[:, :lens_x]  # RGB head
        x = torch.cat([event_x, x], dim=1)
        # x = x[:, lens_x//2:]  # event head
        # x = torch.cat([z, x], dim=1)
        # re-concatenate with the template, which may be further used by other modules
        # x, event_x = x[:, :lens_x//2], x[:, lens_x//2:]
        # x = x[:, -lens_x//2:]

        aux_dict = {
            "attn": attn,
            "removed_indexes_s": removed_indexes_s,  # used for visualization
        }

        return x, aux_dict

    def forward(self, z, x, event_z, event_x,
                ce_template_mask=None, ce_keep_rate=None,
                tnc_keep_rate=None, return_last_attn=False):
        x, aux_dict = self.forward_features(z, x, event_z, event_x,
                                            ce_template_mask=ce_template_mask,
                                            ce_keep_rate=ce_keep_rate,)
        return x, aux_dict


def _create_vision_transformer(pretrained=False, **kwargs):
    model = VisionTransformerCE(**kwargs)

    if pretrained:
        if 'npz' in pretrained:
            model.load_pretrained(pretrained, prefix='')
        else:
            checkpoint = torch.load(pretrained, map_location="cpu")
            missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=False)
            print('Load pretrained model from: ' + pretrained)

    return model


def vit_base_patch16_224_ce(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(pretrained=pretrained, **model_kwargs)
    return model


def vit_large_patch16_224_ce(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer(pretrained=pretrained, **model_kwargs)
    return model
```

Its train loss is about 0.4 and its val loss about 1.5.

Modified model ① splits the four inputs apart and builds two new same-modality inputs, processed as two branches. Its imports, `__init__`, `forward`, `_create_vision_transformer`, and the `vit_base`/`vit_large` factory functions are identical to the original file above; only `forward_features` changes:

```python
# Split the four inputs apart; build two same-modality inputs, processed as two branches.

    def forward_features(self, z, x, event_z, event_x,
                         mask_z=None, mask_x=None,
                         ce_template_mask=None, ce_keep_rate=None,
                         return_last_attn=False):
        # Branch 1: RGB stream
        B, H, W = x.shape[0], x.shape[2], x.shape[3]

        x = self.patch_embed(x)
        z = self.patch_embed(z)
        z += self.pos_embed_z
        x += self.pos_embed_x

        if mask_z is not None and mask_x is not None:
            mask_z = F.interpolate(mask_z[None].float(), scale_factor=1. / self.patch_size).to(torch.bool)[0]
            mask_z = mask_z.flatten(1).unsqueeze(-1)

            mask_x = F.interpolate(mask_x[None].float(), scale_factor=1. / self.patch_size).to(torch.bool)[0]
            mask_x = mask_x.flatten(1).unsqueeze(-1)

            mask_x = combine_tokens(mask_z, mask_x, mode=self.cat_mode)
            mask_x = mask_x.squeeze(-1)

        if self.add_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            cls_tokens = cls_tokens + self.cls_pos_embed

        if self.add_sep_seg:
            x += self.search_segment_pos_embed
            z += self.template_segment_pos_embed

        x = combine_tokens(z, x, mode=self.cat_mode)
        if self.add_cls_token:
            x = torch.cat([cls_tokens, x], dim=1)
        x = self.pos_drop(x)

        lens_z = self.pos_embed_z.shape[1]
        lens_x = self.pos_embed_x.shape[1]

        global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device)
        global_index_t = global_index_t.repeat(B, 1)
        global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device)
        global_index_s = global_index_s.repeat(B, 1)

        removed_indexes_s = []
        for i, blk in enumerate(self.blocks):
            x, global_index_t, global_index_s, removed_index_s, attn = \
                blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate)
            if self.ce_loc is not None and i in self.ce_loc:
                removed_indexes_s.append(removed_index_s)
        x = self.norm(x)  # [bs, n_patch, dim] = [bs, 320, 768], where 320 = 64 + 256

        # Branch 2: event stream
        event_x = self.pos_embed_event(event_x)
        event_z = self.pos_embed_event(event_z)
        event_x += self.pos_embed_x
        event_z += self.pos_embed_z
        event_x = combine_tokens(event_z, event_x, mode=self.cat_mode)
        if self.add_cls_token:
            event_x = torch.cat([cls_tokens, event_x], dim=1)

        lens_z = self.pos_embed_z.shape[1]
        lens_x = self.pos_embed_x.shape[1]
        global_index_t1 = torch.linspace(0, lens_z - 1, lens_z).to(event_x.device)
        global_index_t1 = global_index_t1.repeat(B, 1)
        global_index_s1 = torch.linspace(0, lens_x - 1, lens_x).to(event_x.device)
        global_index_s1 = global_index_s1.repeat(B, 1)

        removed_indexes_s1 = []
        for i, blk in enumerate(self.blocks):
            event_x, global_index_t1, global_index_s1, removed_index_s1, attn = \
                blk(event_x, global_index_t1, global_index_s1, mask_x, ce_template_mask, ce_keep_rate)
            if self.ce_loc is not None and i in self.ce_loc:
                removed_indexes_s1.append(removed_index_s1)
        event_x = self.norm(event_x)

        # print('x.shape: ', x.shape)              # x.shape: torch.Size([2, 320, 768])
        # print('event_x.shape: ', event_x.shape)  # event_x.shape: torch.Size([2, 320, 768])
        x_cat = torch.cat([x, event_x], dim=1)
        # print('x_cat.shape: ', x_cat.shape)      # x_cat.shape: torch.Size([2, 640, 768])
        x = x_cat

        aux_dict = {
            "attn": attn,
            "removed_indexes_s": removed_indexes_s,  # used for visualization
        }
        return x, aux_dict
```

With this version, train loss is around 0.5 and val loss around 1.2.

Code ② keeps the same two-branch design but adds an interaction module. It is identical to model ① above except for one extra import, one extra line at the end of `__init__`, and `forward_features`:

```python
# Same two-branch design; adds the Counter_Guide_Enhanced interaction module.
from .ad_counter_guide import Counter_Guide_Enhanced

    # at the end of __init__:
        self.init_weights(weight_init)
        # add the counter_guide interaction module
        # self.counter_guide = Counter_Guide(768, 768)
        self.counter_guide = Counter_Guide_Enhanced(768, 768)

    def forward_features(self, z, x, event_z, event_x,
                         mask_z=None, mask_x=None,
                         ce_template_mask=None, ce_keep_rate=None,
                         return_last_attn=False):
        # Token embedding, masking, and index setup for both branches are the same
        # as in model ① above (but no block loop runs in between); the two branches
        # then share a single loop over the blocks:
        for i, blk in enumerate(self.blocks):
            # first branch
            x, global_index_t, global_index_s, removed_index_s, attn = \
                blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate)
            # second branch
            event_x, global_index_t1, global_index_s1, removed_index_s1, attn1 = \
                blk(event_x, global_index_t1, global_index_s1, mask_x, ce_template_mask, ce_keep_rate)

            if self.ce_loc is not None and i in self.ce_loc:
                removed_indexes_s.append(removed_index_s)
                removed_indexes_s1.append(removed_index_s1)

            # add the counter_guide module at layers 1 and 2 to test early fusion
            if i == 0:
                enhanced_x, enhanced_event_x = self.counter_guide(x, event_x)
                # add the enhanced features back onto the originals
                x = x + enhanced_x
                event_x = event_x + enhanced_event_x

        # apply LayerNorm to both branches
        x = self.norm(x)
        event_x = self.norm(event_x)

        x_cat = torch.cat([event_x, x], dim=1)
        x = x_cat

        aux_dict = {
            "attn": attn,
            "attn1": attn1,
            "removed_indexes_s": removed_indexes_s,  # used for visualization
            "removed_indexes_s1": removed_indexes_s1,
        }
        return x, aux_dict
```

With code ②, train loss is about 1.27 and val loss about 1.45.
answer: Based on the information you provided, you described an original model and two modified configurations with their results, and asked about the effect of the code changes. Your goal is to tune the model configuration to improve training: lowering both the training loss (train loss) and the validation loss (val loss), and shrinking the gap between them.

# Results of the original model
- **train loss**: ~0.4
- **val loss**: ~1.5

# Results of modification ① (two same-modality branches)
- **train loss**: ~0.5
- **val loss**: ~1.2

# Results of code ② (two branches plus Counter_Guide_Enhanced)
- **train loss**: ~1.27
- **val loss**: ~1.45

With code ②, the gap between train loss and val loss is smaller (suggesting less overfitting), but the train loss itself rose markedly, which means the model now fits the training set worse. Per your description, the main changes are structural (splitting the four inputs into two same-modality pairs processed by two branches, plus the interaction module) along with hyperparameter adjustments.

# Suggestions for improving the model

1. **Data augmentation**: stronger augmentation can improve generalization and reduce the risk of overfitting.
2. **Regularization**: add or retune regularizers such as Dropout or L2 weight decay to curb overfitting.
3. **Model complexity**: if the model is too complex for the amount of data, try simplifying it by reducing depth or parameter count.
4. **Learning-rate schedule**: try learning-rate decay strategies such as cosine decay or warmup, and adaptive optimizers such as AdamW (see the sketch after this list).
5. **Evaluate each modification**: since you changed the architecture, assess the impact of each change carefully; if possible, introduce the modifications one step at a time and measure each step.
6. **Watch the loss curves across training stages**: analyze how train loss and val loss evolve, and focus on the steps that cause clear degradation or overfitting.
7. **Batch size**: consider adjusting the batch size; smaller batches may help the model learn finer-grained features, but may also require retuning the learning rate.

# Best practices

- **Continuous monitoring and evaluation**: track train loss and val loss throughout training to catch overfitting or underfitting early.
- **Experiment tracking**: record every run in detail (model configuration, data splits, training strategy) so results can be analyzed and reproduced later.
- **Small steps**: make one adjustment at a time and evaluate its effect, rather than changing many things at once, so you can attribute performance changes accurately.
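Picking up point 4 above, here is a minimal sketch of AdamW with linear warmup followed by cosine decay, using only standard PyTorch schedulers. The base learning rate, warmup length, epoch count, and the `nn.Linear` stand-in for the real `VisionTransformerCE` are illustrative assumptions, not values from your training config:

```python
import math

import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# Illustrative assumptions only:
BASE_LR, WARMUP_EPOCHS, TOTAL_EPOCHS = 1e-4, 5, 60

model = nn.Linear(8, 8)  # stand-in for the real VisionTransformerCE
optimizer = AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)

def warmup_cosine(epoch: int) -> float:
    """LR multiplier: linear warmup for WARMUP_EPOCHS, then cosine decay to zero."""
    if epoch < WARMUP_EPOCHS:
        return (epoch + 1) / WARMUP_EPOCHS
    progress = (epoch - WARMUP_EPOCHS) / max(1, TOTAL_EPOCHS - WARMUP_EPOCHS)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine)

for epoch in range(TOTAL_EPOCHS):
    # ... run one training epoch here ...
    scheduler.step()
```

Warmup is particularly relevant here because the newly added fusion module starts from random weights while the backbone may be pretrained; a few gentle epochs let the new parameters settle before the full learning rate is applied.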
question: The counter_guide module that was introduced is:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# A new cross-attention module, designed to draw on the strengths of the attention mechanism in ViT
class Cross_Attention(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(Cross_Attention, self).__init__()
        self.query = nn.Linear(input_channels, output_channels)
        self.key = nn.Linear(input_channels, output_channels)
        self.value = nn.Linear(input_channels, output_channels)
        self.scale = output_channels ** -0.5  # scaling factor, guards against vanishing/exploding softmax gradients

    def forward(self, assistant, present):
        # Produce queries, keys, and values
        q = self.query(assistant)
        k = self.key(present)
        v = self.value(present)

        # Compute attention scores
        attention_scores = torch.bmm(q, k.transpose(1, 2)) * self.scale
        attention = F.softmax(attention_scores, dim=-1)

        # Apply the attention weights
        attended_features = torch.bmm(attention, v)
        return attended_features


class Multi_Context(nn.Module):
    def __init__(self, input_channels, output_channels):  # fixed: `init` corrected to `__init__`
        super(Multi_Context, self).__init__()
        self.linear1 = nn.Linear(input_channels, output_channels)
        self.linear2 = nn.Linear(input_channels, output_channels)
        self.linear3 = nn.Linear(input_channels, output_channels)
        self.linear_final = nn.Linear(output_channels * 3, output_channels)

    def forward(self, x):
        x1 = F.relu(self.linear1(x))
        x2 = F.relu(self.linear2(x))
        x3 = F.relu(self.linear3(x))
        x = torch.cat([x1, x2, x3], dim=-1)  # concatenate along the feature dimension
        x = self.linear_final(x)
        return x


class Adaptive_Weight(nn.Module):
    def __init__(self, input_channels):  # fixed: `init` corrected to `__init__`
        super(Adaptive_Weight, self).__init__()
        self.fc1 = nn.Linear(input_channels, input_channels // 4)
        self.fc2 = nn.Linear(input_channels // 4, input_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x_avg = torch.mean(x, dim=1)  # global average pooling along the 'sequence length' dimension
        weight = F.relu(self.fc1(x_avg))
        weight = self.fc2(weight)
        weight = self.sigmoid(weight).unsqueeze(1)
        out = x * weight
        return out


class Counter_Guide_Enhanced(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(Counter_Guide_Enhanced, self).__init__()
        self.mc = Multi_Context(input_channels, output_channels)
        self.cross_attention = Cross_Attention(output_channels, output_channels)
        # combined with adaptive re-weighting
        self.adaptive_weight = Adaptive_Weight(output_channels)

    def forward(self, x, event_x):
        # First enhance the event features with Multi-Context
        enhanced_event_x = self.mc(event_x)

        # Use Cross-Attention for feature interaction (x and event_x enhance each other)
        enhanced_x = self.cross_attention(x, enhanced_event_x)
        enhanced_event_x_via_x = self.cross_attention(event_x, enhanced_x)

        # Further re-weight the enhanced features via the adaptive weights
        adaptive_enhanced_x = self.adaptive_weight(enhanced_x)
        adaptive_enhanced_event_x = self.adaptive_weight(enhanced_event_x_via_x)

        return adaptive_enhanced_x, adaptive_enhanced_event_x
```
answer: The `Counter_Guide_Enhanced` class you provided is a high-level feature-enhancement module that combines multi-context enhancement, cross-attention, and adaptive re-weighting. It is inserted into the Vision Transformer (ViT) model to improve how features are learned and fused. Here is a breakdown of its components and what each one does:

# Multi_Context
- The `Multi_Context` module processes the input features `x` through three separate linear paths, each followed by a ReLU activation. It then concatenates the three results along the feature dimension and passes them through a final linear layer, producing an enhanced representation that combines multiple contexts.

# Cross_Attention
- `Cross_Attention` is a cross-attention mechanism that lets two sets of features (`assistant` and `present`) interact and enhance each other. It builds queries from one set and keys/values from the other, computes attention scores, and produces attended features, so that information from each source can complement the other.

# Adaptive_Weight
- `Adaptive_Weight` learns adaptive per-channel weights for the input features using global average pooling followed by a two-layer fully connected network. A sigmoid keeps the weights between 0 and 1, and multiplying them back onto the input rescales the importance of each channel.

# Counter_Guide_Enhanced
- `Counter_Guide_Enhanced` ties all of the above together: it first enhances the event features (`event_x`) with `Multi_Context`, then uses `Cross_Attention` for interaction between `x` and `enhanced_event_x`, applies a second `Cross_Attention` pass to enhance `event_x` in the reverse direction, and finally applies `Adaptive_Weight` to re-weight both enhanced feature sets before fusing them.

This design philosophy is valuable for multi-modal or multi-source data. In tasks such as video tracking, event detection, or image classification, mining and fusing complementary information from different modalities (e.g., RGB frames and event-camera data) can improve both accuracy and robustness.

`Counter_Guide_Enhanced` is introduced to strengthen the feature representation: by enhancing features from different sources and integrating them effectively, the model can understand and predict complex scenes better. This strategy can serve as a powerful tool for feature extraction and fusion in advanced vision tasks.
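To make the tensor flow concrete, here is a minimal shape check of the module exactly as defined above. The batch size of 2 and token count of 320 (64 template + 256 search) mirror the shapes printed in the tracker code, but are illustrative assumptions:

```python
import torch

# Assumed shapes: batch 2, 320 tokens, embedding dim 768.
guide = Counter_Guide_Enhanced(768, 768)
x = torch.randn(2, 320, 768)        # RGB-branch tokens
event_x = torch.randn(2, 320, 768)  # event-branch tokens

enhanced_x, enhanced_event_x = guide(x, event_x)
print(enhanced_x.shape, enhanced_event_x.shape)  # both torch.Size([2, 320, 768])
```

Note that with `input_channels == output_channels == 768`, the outputs keep the token shapes, which is what allows the tracker's residual additions `x = x + enhanced_x` to type-check.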
question: However, after introducing this module and training with the original configuration, the train loss only drops to about 1.5, whereas without Counter_Guide_Enhanced it can drop to about 0.4.
answer: A higher training loss after introducing the `Counter_Guide_Enhanced` module can have several causes:

# 1. Increased model complexity:
`Counter_Guide_Enhanced` adds multi-context paths, a cross-attention mechanism, and adaptive weighting, which raises the model's complexity. These mechanisms are meant to strengthen the feature representation, but they can also make the model harder to optimize, especially when the initial learning rate, regularization, and other hyperparameters have not been retuned for the new module.

# Remedies:
- **Lower the learning rate**: consider reducing the initial learning rate so that overly large steps early in training do not skip past good minima.
- **Strengthen regularization**: add `Dropout` or weight decay, or adjust how normalization layers are used, to curb overfitting and improve generalization.
- **Shrink the module**: simplify parts of `Counter_Guide_Enhanced`, e.g., fewer linear layers or smaller output dimensions, and check whether the model becomes easier to optimize.

# 2. Poor initialization:
Parameter initialization strongly affects training. If the new module is initialized badly, training can become unstable, and gradients may vanish or explode.

# Remedies:
- **Adjust the initialization**: use an appropriate scheme such as Kaiming or Xavier initialization for the linear layers and other parameters of `Counter_Guide_Enhanced` (see the sketch after this answer).

# 3. Mismatched data handling:
If `Counter_Guide_Enhanced` implicitly expects different preprocessing or normalization than the other modules, incorrectly processed inputs can make the enhanced features less effective than intended.

# Remedies:
- **Review preprocessing and normalization**: make sure all inputs are processed the way the module expects.

# 4. Training dynamics:
As the model grows more complex, the learning rate, number of training epochs, etc., may need to be adapted dynamically during training.

# Remedies:
- **Use a learning-rate scheduler**: adjust the learning rate as training progresses, e.g., with cosine annealing or another dynamic schedule.
- **Train longer**: a more complex model may need more time to converge; consider increasing the number of epochs.

# 5. Module/task mismatch:
`Counter_Guide_Enhanced` may not suit every task or dataset. In some cases, extra complexity actually hurts performance.

# Remedies:
- **Evaluate task and dataset fit**: check carefully whether the task's requirements match what the enhancement module offers, and whether every input needs to pass through it.

In short, when adding a new module like `Counter_Guide_Enhanced`, carefully consider model complexity, initialization strategy, data handling, and hyperparameter tuning, and find the best configuration experimentally.
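For the initialization point above, here is a minimal sketch that Xavier-initializes every linear layer inside the module. The helper name `init_counter_guide` and the choice of Xavier over Kaiming are illustrative assumptions; it relies on the `Counter_Guide_Enhanced` class quoted earlier:

```python
import torch.nn as nn

def init_counter_guide(module: nn.Module) -> None:
    """Xavier-initialize every nn.Linear inside the module; zero the biases."""
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

guide = Counter_Guide_Enhanced(768, 768)  # class as defined above
init_counter_guide(guide)
```

A common additional trick when a new branch is added residually onto a pretrained backbone is to zero-initialize its last projection (here, `linear_final` or the cross-attention `value` layer) so that the fusion starts close to a no-op and the backbone is not perturbed early in training.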
question: Now consider simplifying the output dimension inside Counter_Guide_Enhanced. Code ① is the Counter_Guide_Enhanced implementation quoted in full in the previous question (`Cross_Attention`, `Multi_Context`, `Adaptive_Weight`, and `Counter_Guide_Enhanced`, all at 768 channels). The calling code is the two-branch VisionTransformerCE from code ② above, differing only in three details:

```python
# Differences from code ② above:

# 1. One additional import alongside the existing one:
from .new_counter_guide import Counter_Guide
from .ad_counter_guide import Counter_Guide_Enhanced

# 2. Inside the block loop, the event branch re-binds `attn` rather than `attn1`,
#    so the RGB branch's attention map is overwritten:
            event_x, global_index_t1, global_index_s1, removed_index_s1, attn = \
                blk(event_x, global_index_t1, global_index_s1, mask_x, ce_template_mask, ce_keep_rate)

# 3. aux_dict returns only the (overwritten) "attn" and "removed_indexes_s":
        aux_dict = {
            "attn": attn,
            "removed_indexes_s": removed_indexes_s,  # used for visualization
        }
```

Given this setup, how should the output dimension be reduced, as suggested?
answer: To lower the output dimension inside the `Counter_Guide_Enhanced` module, you can reduce the output dimensions of the cross-modal interaction (`Cross_Attention`) and multi-context (`Multi_Context`) blocks. The goal is to cut parameters and computational cost while trying to preserve, or even improve, model performance.

# 1. Reduce the output dimension of `Cross_Attention` and `Multi_Context`
You can shrink these modules by passing a smaller `output_channels`. For example, if the original dimension is 768, try halving it to 384 or less, depending on your needs.

# 2. Example adjustment
Here is a modified example that reduces the output dimension from 768 to 384:

```python
class Cross_Attention(nn.Module):
    def __init__(self, input_channels, output_channels=384):  # default output dim reduced to 384
        ...  # rest unchanged

class Multi_Context(nn.Module):
    def __init__(self, input_channels, output_channels=384):  # default output dim reduced to 384
        ...  # rest unchanged
```

# 3. Adapt the interaction blocks inside `Counter_Guide_Enhanced`
Inside `Counter_Guide_Enhanced`, make sure the input and output dimensions of the cross-attention and multi-context blocks stay consistent. Note that the tracker adds the module's outputs residually onto 768-dimensional ViT tokens (`x = x + enhanced_x`), so if you lower the internal dimension, you also need projections into and out of the reduced width (a sketch follows below).

# 4. Example adjustment:

```python
class Counter_Guide_Enhanced(nn.Module):
    def __init__(self, input_channels, output_channels=384):  # output dim reduced to 384
        super(Counter_Guide_Enhanced, self).__init__()
        self.mc = Multi_Context(input_channels, output_channels)
        self.cross_attention = Cross_Attention(output_channels, output_channels)
        self.adaptive_weight = Adaptive_Weight(output_channels)
        ...  # rest unchanged
```

# 5. Mind the trade-off
Reducing the output dimension can weaken the representational capacity of the model. Reduce it gradually and re-check the performance metrics regularly to find a good balance.

This way you can lower the computational complexity and parameter count of `Counter_Guide_Enhanced` while trying to keep the model effective.
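To make the dimension bookkeeping explicit, here is a minimal sketch of the reduced-width variant. The class name `Counter_Guide_Slim` and the `proj_in`/`proj_out`/`hidden_channels` names are illustrative assumptions, not the repo's code; it reuses the `Multi_Context`, `Cross_Attention`, and `Adaptive_Weight` classes quoted earlier, runs them at the reduced width, and projects back to the ViT embedding dimension so the residual additions in `forward_features` still type-check:

```python
import torch.nn as nn

class Counter_Guide_Slim(nn.Module):
    """Sketch: run the interaction at a reduced width, then project back to embed_dim."""

    def __init__(self, embed_dim=768, hidden_channels=384):
        super().__init__()
        self.proj_in = nn.Linear(embed_dim, hidden_channels)   # 768 -> 384
        self.mc = Multi_Context(hidden_channels, hidden_channels)
        self.cross_attention = Cross_Attention(hidden_channels, hidden_channels)
        self.adaptive_weight = Adaptive_Weight(hidden_channels)
        self.proj_out = nn.Linear(hidden_channels, embed_dim)  # 384 -> 768

    def forward(self, x, event_x):
        # Project both token streams down to the reduced width
        x_s, event_s = self.proj_in(x), self.proj_in(event_x)

        # Same interaction pattern as Counter_Guide_Enhanced, at 384 channels
        enhanced_event = self.mc(event_s)
        enhanced_x = self.cross_attention(x_s, enhanced_event)
        enhanced_event = self.cross_attention(event_s, enhanced_x)

        # Re-weight, then project back to embed_dim for the residual additions
        out_x = self.proj_out(self.adaptive_weight(enhanced_x))
        out_event = self.proj_out(self.adaptive_weight(enhanced_event))
        return out_x, out_event
```

Without the `proj_in`/`proj_out` pair, setting `output_channels=384` as in the answer's snippets would fail at runtime: `Cross_Attention`'s query layer would receive 768-dimensional tokens while expecting 384, and the 384-dimensional outputs could not be added onto the 768-dimensional ViT tokens. Halving the internal width roughly halves the attention-score computation (which scales with N²·d) and quarters the parameters of each d×d linear layer, at the cost of the two extra projection layers.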