import math
import torch
import torch.nn.functional as F
from torch import nn
from .backbone import resnet
from .backbone import densenet
from base import BaseModel
from utils.helpers import initialize_weights, set_trainable
from itertools import chain
import numpy as np
from .repconvs import RepConv_dict

class _PSPModule(nn.Module):
    """Pyramid pooling module whose fusion conv is a reparameterizable block.

    Each stage average-pools the input to one size from ``bin_sizes``,
    projects it with a 1x1 conv + ``norm_layer`` + ReLU, and is bilinearly
    upsampled back to the input's spatial size. The input and all stage
    outputs are concatenated on dim 1 and fused by a 3x3 ``RepConv`` down to
    ``in_channels // len(bin_sizes)`` channels.

    NOTE(review): ``RepConv`` is a module-level global that ``RepPSP.__init__``
    binds from ``RepConv_dict``; building this module before any ``RepPSP``
    has been constructed raises NameError.

    Args:
        in_channels: channel count of the incoming feature map.
        bin_sizes: ``output_size`` values for ``nn.AdaptiveAvgPool2d``, one per stage.
        norm_layer: normalization factory, called as ``norm_layer(channels)``.
        deploy: forwarded to ``RepConv`` (presumably selects its fused
            inference form — confirm against the RepConv classes).
    """

    def __init__(self, in_channels, bin_sizes, norm_layer, deploy=False):
        # nn.Module.__init__ must run before any attribute is assigned on self.
        super(_PSPModule, self).__init__()
        self.deploy = deploy
        out_channels = in_channels // len(bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s, norm_layer)
                                     for b_s in bin_sizes])
        # NOTE(review): the bottleneck holds only the RepConv. The stages use
        # conv -> norm -> ReLU; confirm whether RepConv already includes its own
        # norm/activation or whether further layers should follow it here.
        self.bottleneck = nn.Sequential(
            RepConv(in_channels + (out_channels * len(bin_sizes)), out_channels,
                    kernel_size=3, padding=1, bias=False, deploy=self.deploy),
        )

    def _make_stages(self, in_channels, out_channels, bin_sz, norm_layer):
        """Build one pyramid stage: adaptive avg-pool -> 1x1 conv -> norm -> ReLU."""
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = norm_layer(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)

    def forward(self, features):
        """Concatenate ``features`` with its upsampled pyramid stages and fuse them.

        Dims 2 and 3 of ``features`` are read as height and width. Returns the
        output of the RepConv bottleneck.
        """
        h, w = features.size(2), features.size(3)
        pyramids = [features]
        pyramids.extend(F.interpolate(stage(features), size=(h, w), mode='bilinear',
                                      align_corners=True) for stage in self.stages)
        return self.bottleneck(torch.cat(pyramids, dim=1))

class RepPSP(BaseModel):
    """PSPNet with a ResNet backbone and reparameterizable (RepConv) blocks in the head.

    Args:
        num_classes: number of output channels of the segmentation logits.
        deploy: forwarded to every ``RepConv`` block built for the head.
            Defaults to ``False`` so configs that omit it still construct.
        repconv: key into ``RepConv_dict`` selecting the RepConv class.
        in_channels: input image channels; when not 3, the stem's first conv
            is replaced by a new 7x7 conv that accepts ``in_channels``.
        backbone: name of the constructor looked up in ``resnet``.
        pretrained: forwarded to the backbone constructor.
        use_aux: when training, ``forward`` also returns auxiliary logits.
        freeze_bn: put every ``nn.BatchNorm2d`` in eval mode after construction.
        freeze_backbone: pass the stem and layer1-layer4 to
            ``set_trainable(..., False)``.

    Raises:
        KeyError: if ``repconv`` is not a key of ``RepConv_dict``.
    """

    def __init__(self, num_classes, deploy=False, repconv=None, in_channels=3, backbone='resnet152', pretrained=True, use_aux=True,
                freeze_bn=False, freeze_backbone=False):
        super(RepPSP, self).__init__()
        if repconv not in RepConv_dict:
            # Fail with the valid choices instead of a bare ``KeyError: None``.
            raise KeyError('Unknown repconv {!r}; expected one of {}'.format(
                repconv, sorted(RepConv_dict)))
        # _PSPModule and the auxiliary head resolve ``RepConv`` as a module
        # global when they are built, so the selected class is bound here.
        global RepConv
        RepConv = RepConv_dict[repconv]
        norm_layer = nn.BatchNorm2d
        model = getattr(resnet, backbone)(pretrained, norm_layer=norm_layer)
        # Width of layer4's output (the input width of the backbone's fc).
        m_out_sz = model.fc.in_features
        self.use_aux = use_aux
        self.deploy = deploy

        # Stem: the backbone's first four children (conv/bn/relu/maxpool in a
        # standard ResNet — confirm for custom backbones).
        stem = list(model.children())[:4]
        if in_channels != 3:
            stem[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.initial = nn.Sequential(*stem)
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

        # Main head: pyramid pooling with four bins reduces m_out_sz to
        # m_out_sz // 4 channels, then a 1x1 conv produces the class logits.
        self.master_branch = nn.Sequential(
            _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=norm_layer, deploy=self.deploy),
            nn.Conv2d(m_out_sz // 4, num_classes, kernel_size=1),
        )

        # Auxiliary head on layer3's output, assumed to have m_out_sz // 2
        # channels as in standard ResNets — confirm for custom backbones.
        self.auxiliary_branch = nn.Sequential(
            RepConv(m_out_sz // 2, m_out_sz // 4, kernel_size=3, padding=1, bias=False, deploy=self.deploy),
            nn.Conv2d(m_out_sz // 4, num_classes, kernel_size=1),
        )

        initialize_weights(self.master_branch, self.auxiliary_branch)
        if freeze_bn:
            self.freeze_bn()
        if freeze_backbone:
            set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False)

    def forward(self, x):
        """Return logits resized to the input's spatial size (dims 2 and 3).

        When training with ``use_aux``, returns ``(output, aux)``, where ``aux``
        comes from the auxiliary head applied to layer3's output.
        """
        input_size = (x.size(2), x.size(3))
        x = self.initial(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_aux = self.layer3(x)
        x = self.layer4(x_aux)

        # align_corners=False is what F.interpolate uses for bilinear mode when
        # the argument is omitted; stating it explicitly avoids the warning.
        # Interpolating to exactly input_size makes any further cropping a no-op.
        output = F.interpolate(self.master_branch(x), size=input_size,
                               mode='bilinear', align_corners=False)

        if self.training and self.use_aux:
            aux = F.interpolate(self.auxiliary_branch(x_aux), size=input_size,
                                mode='bilinear', align_corners=False)
            return output, aux
        return output

    def get_backbone_params(self):
        """Return an iterator over the stem and layer1-layer4 parameters."""
        return chain(self.initial.parameters(), self.layer1.parameters(), self.layer2.parameters(),
                     self.layer3.parameters(), self.layer4.parameters())

    def get_decoder_params(self):
        """Return an iterator over the main and auxiliary head parameters."""
        return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters())

    def freeze_bn(self):
        """Put every nn.BatchNorm2d submodule in eval mode."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()

class RepPSPDense(BaseModel):
    '''PSPNet with a DenseNet backbone and RepConv blocks in the decoder head.

    Args:
        num_classes: number of output channels of the segmentation logits.
        deploy: forwarded to the ``_PSPModule`` and to the auxiliary ``RepConv``.
        in_channels: input image channels; when not 3 (or when not
            pretrained) the stem ``block0`` is meant to be rebuilt from 3x3 convs.
        backbone: name of the constructor looked up in ``densenet``.
        pretrained: forwarded to the backbone constructor.
        use_aux: when training, ``forward`` also returns auxiliary logits.
        freeze_bn: put every ``nn.BatchNorm2d`` in eval mode after construction.
        **_: extra keyword arguments are accepted and ignored.

    NOTE(review): unlike ``RepPSP`` there is no ``repconv`` argument; a
    ``repconv`` config key is silently swallowed by ``**_``. The ``RepConv``
    used below is a module global bound only by ``RepPSP.__init__``. Building
    this model before any ``RepPSP`` therefore raises NameError. Otherwise it
    uses whichever block the most recent ``RepPSP`` selected — confirm intent.
    '''
    def __init__(self, num_classes, deploy=False, in_channels=3, backbone='densenet201', pretrained=True, use_aux=True, freeze_bn=False, **_):
        super(RepPSPDense, self).__init__()
        self.use_aux = use_aux
        self.deploy = deploy
        model = getattr(densenet, backbone)(pretrained)
        # Channel counts used to size the decoder heads.
        m_out_sz = model.classifier.in_features
        aux_out_sz = model.features.transition3.conv.out_channels

        # NOTE(review): the stem section below is incomplete and does not parse:
        # - The "[nn.Conv2d(64, 64, ...)] * 2" line continues an expression
        #   with no opening call.
        # - The "self.block0 = nn.Sequential(" call is never closed.
        # - The pretrained-stem assignment sits in the same `if` with no
        #   `else:`. It would overwrite the scratch stem, and block0 would be
        #   unset when the condition is false.
        # Restore the missing lines before use. Also note that "[...] * 2"
        # repeats the *same* module objects, so if that list is used the two
        # 3x3 convs share weights — confirm that is intended.
        if not pretrained or in_channels != 3:
            # If we're training from scratch, it is better to use 3x3 convs
            block0 = [nn.Conv2d(in_channels, 64, 3, stride=2, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)]
                [nn.Conv2d(64, 64, 3, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)] * 2
            self.block0 = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.block0 = nn.Sequential(*list(model.features.children())[:4])
        self.block1 = model.features.denseblock1
        self.block2 = model.features.denseblock2
        self.block3 = model.features.denseblock3
        self.block4 = model.features.denseblock4

        self.transition1 = model.features.transition1
        # No pooling
        # NOTE(review): the arguments of both nn.Sequential(...) calls below are
        # missing, and neither call is closed. Given the "No pooling" intent and
        # the dilations applied to block3/block4 below, these were presumably
        # the transition layers with their final pooling layer dropped — confirm.
        self.transition2 = nn.Sequential(
        self.transition3 = nn.Sequential(

        # Dilate every submodule whose name contains 'conv2': rate 2 in block3
        # and rate 4 in block4, with padding set equal to the dilation
        # (size-preserving for stride-1 3x3 kernels — kernel/stride assumed, confirm).
        for n, m in self.block3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding = (2,2), (2,2)
        for n, m in self.block4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding = (4,4), (4,4)

        # Main head: pyramid pooling with four bins reduces m_out_sz to
        # m_out_sz // 4 channels, then a 1x1 conv produces class logits.
        # NOTE(review): this nn.Sequential(...) is missing its closing ")".
        self.master_branch = nn.Sequential(
            _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=nn.BatchNorm2d, deploy=self.deploy),
            nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)

        # Auxiliary head on transition3's output (aux_out_sz channels).
        # NOTE(review): this nn.Sequential(...) is missing its closing ")".
        self.auxiliary_branch = nn.Sequential(
            RepConv(aux_out_sz, m_out_sz//4, kernel_size=3, padding=1, bias=False, deploy=self.deploy),
            nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)

        initialize_weights(self.master_branch, self.auxiliary_branch)
        if freeze_bn: self.freeze_bn()

    def forward(self, x):
        '''Return logits bilinearly resized to the input's spatial size (dims 2 and 3).

        When training with ``use_aux``, returns ``(output, aux)``, where ``aux``
        comes from the auxiliary head applied to transition3's output.
        '''
        input_size = (x.size()[2], x.size()[3])

        x = self.block0(x)
        x = self.block1(x)
        x = self.transition1(x)
        x = self.block2(x)
        x = self.transition2(x)
        x = self.block3(x)
        x_aux = self.transition3(x)
        x = self.block4(x_aux)

        output = self.master_branch(x)
        output = F.interpolate(output, size=input_size, mode='bilinear')

        if self.training and self.use_aux:
            aux = self.auxiliary_branch(x_aux)
            aux = F.interpolate(aux, size=input_size, mode='bilinear')
            return output, aux
        return output

    def get_backbone_params(self):
        # Chain of backbone parameter iterators.
        # NOTE(review): the chain(...) call below is truncated (no closing
        # parenthesis); transition3's parameters are presumably the missing
        # argument. As visible, block4's parameters are returned by neither
        # this method nor get_decoder_params — confirm they are still optimized.
        return chain(self.block0.parameters(), self.block1.parameters(), self.block2.parameters(), 
                   self.block3.parameters(), self.transition1.parameters(), self.transition2.parameters(),

    def get_decoder_params(self):
        # Parameters of the main and auxiliary heads.
        return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters())

    def freeze_bn(self):
        # Put every nn.BatchNorm2d submodule in eval mode.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d): module.eval()


# Registry mapping the ``repconv`` config string to its reparameterizable
# conv-block class.
# NOTE(review): ACB, DBB, DBB2, DBB3, DBB4 and RepVGG are not imported in this
# file. If this assignment lives in the same module as the models above, it
# rebinds the RepConv_dict imported from .repconvs and needs those names in
# scope — confirm where it belongs.
RepConv_dict = {
    'acb': ACB,
    'dbb': DBB,
    'dbb2': DBB2,
    'dbb3': DBB3,
    'dbb4': DBB4,
    'repvgg': RepVGG,
}


repvgg: 1x1 + 3x3 + identity
dbb:    k_origin + k_1x1 + k_1x1_kxk_merged + k_1x1_avg_merged
dbb2:   k_origin + k_1x1 + k_1x1_kxk_merged
dbb3:   k_origin + k_1x1 + k_1x1_kxk_merged + identity
dbb4:   k_origin + k_1x1 + k_1x1_kxk_merged + k_1x1_avg_merged + identity
acb:    1x3 + 3x1 + 3x3


  "name": "PSPNet",         // training session name
  "n_gpu": 1,               // number of GPUs to use for training.
  "use_synch_bn": true,     // Using Synchronized batchnorm (for multi-GPU usage)

    "arch": {
        "type": "PSPNet",   // name of model architecture to train
        "args": {
            "backbone": "resnet50",     // encoder (backbone) type
            "freeze_bn": false,         // Put BatchNorm layers in eval mode; useful when fine-tuning the model
            "freeze_backbone": false,   // If true, only the decoder is trained
            "repconv": "repvgg"         // RepConv block to use: acb, dbb, dbb2, dbb3, dbb4 or repvgg

    "train_loader": {
        "type": "VOC",          // Selecting data loader
            "data_dir": "data/",  // dataset path
            "batch_size": 32,     // batch size
            "augment": true,      // Use data augmentation
            "crop_size": 380,     // Size of the random crop after rescaling
            "shuffle": true,
            "base_size": 400,     // The image is resized to base_size, then randomly cropped
            "scale": true,        // Random rescaling between 0.5 and 2 before cropping
            "flip": true,         // Random horizontal flip
            "rotate": true,       // Random rotation between -10 and 10 degrees
            "blur": true,         // Add a slight amount of blur to the image
            "split": "train_aug", // Split to use, depend of the dataset
            "num_workers": 8

    "val_loader": {     // Same for val, but no data augmentation, only a center crop
        "type": "VOC",
            "data_dir": "data/",
            "batch_size": 32,
            "crop_size": 480,
            "val": true,
            "split": "val",
            "num_workers": 4

    "optimizer": {
        "type": "SGD",
        "differential_lr": true,      // Using lr/10 for the backbone, and lr for the rest
            "lr": 0.01,               // Learning rate
            "weight_decay": 1e-4,     // Weight decay
            "momentum": 0.9

    "loss": "CrossEntropyLoss2d",     // Loss (see utils/losses.py)
    "ignore_index": 255,              // Label index to ignore (must be set to -1 for the ADE20K dataset)
    "lr_scheduler": {   
        "type": "Poly",               // Learning rate scheduler (Poly or OneCycle)
        "args": {}

    "trainer": {
        "epochs": 80,                 // Number of training epochs
        "save_dir": "saved/",         // Checkpoints are saved in save_dir/models/
        "save_period": 10,            // Save a checkpoint every 10 epochs
        "monitor": "max Mean_IoU",    // Mode and metric used to track model performance
        "early_stop": 10,             // Number of epochs to wait before early stopping (0 to disable)
        "tensorboard": true,        // Enable tensorboard visualization
        "log_dir": "saved/runs",
        "log_per_iter": 20,         

        "val": true,
        "val_per_epochs": 5         // Run validation each 5 epochs


FCN8, UNet, UNetResnet, SegNet, SegResNet, ENet, RepENet, GCN, UperNet, PSPNet, PSPDenseNet, RepPSP, RepPSPDense, DeepLab, RepDeepLab, DeepLab_DUC_HDC, RepDUCHDC





python train.py --config config.json


pip install -r requirements.txt


