DYS-Net

Bases: Module, ABC

Abstract implementation of a Davis-Yin Splitting (DYS) layer in a neural network.

Note

The singular value decomposition of the matrix \(\mathsf{A}\) is used for the projection onto the subspace of all \(\mathsf{x}\) such that \(\mathsf{Ax=b}\).
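In particular, writing the thin SVD as \(\mathsf{A = U \Sigma V^\top}\), the projection has the closed form

\[
P_{C_2}(\mathsf{z}) = \mathsf{z} - \mathsf{A^{+}(Az - b)} = \mathsf{z} - \mathsf{V \Sigma^{+} U^\top (Az - b)},
\]

where \(\mathsf{A^{+}}\) is the Moore-Penrose pseudoinverse. This is what _project_C2 computes, treating singular values below \(10^{-6}\) as zero.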

Parameters:

    Name    Type    Description                              Default
    A       tensor  Matrix for linear system                 required
    b       tensor  Measurement vector for linear system     required
    device  string  Device on which to perform computations  'mps'
    alpha   float   Step size for DYS updates                0.05
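
Each layer applies a single DYS update, implemented in _apply_DYS below. With \(C_1\) the nonnegative orthant and \(C_2 = \mathsf{\{x : Ax = b\}}\), the update reads

\[
\mathsf{x}^k = P_{C_1}(\mathsf{z}^k), \qquad
\mathsf{y}^k = P_{C_2}\!\left(2\mathsf{x}^k - \mathsf{z}^k - \alpha F(\mathsf{x}^k, \mathsf{w})\right), \qquad
\mathsf{z}^{k+1} = \mathsf{z}^k - \mathsf{x}^k + \mathsf{y}^k,
\]

where \(F\) is the gradient of the objective and \(\mathsf{w}\) its parameters.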
Source code in src/dys_opt_net.py
class DYS_opt_net(nn.Module, ABC):
    ''' Abstract implementation of a Davis-Yin Splitting (DYS) layer in a neural network.

        Note:
            The singular value decomposition of the matrix $\mathsf{A}$ is used for the
            projection onto the subspace of all $\mathsf{x}$ such that $\mathsf{Ax=b}$.

        Args:
            A (tensor):      Matrix for linear system
            b (tensor):      Measurement vector for linear system
            device (string): Device on which to perform computations
            alpha (float):   Step size for DYS updates

    '''
    def __init__(self, A, b, device='mps', alpha=0.05):
        super().__init__()
        self.device = device
        self.b = b
        self.alpha = alpha*torch.ones(1, device=self.device)
        self.A = A
        self.n1 = A.shape[0]  # Number of rows of A
        self.n2 = A.shape[1]  # Number of columns of A

        U, s, VT = torch.linalg.svd(self.A, full_matrices=False)
        self.s_inv = torch.where(s >= 1e-6, 1.0 / s, torch.zeros_like(s)).to(self.device)  # pseudoinverse of the singular values
        self.V = torch.t(VT).to(self.device)
        self.UT = torch.t(U).to(self.device)

    def _project_C1(self, x):
        ''' Projection to the non-negative orthant.

        Args:
            x (tensor): point in Euclidean space

        Returns:
            Px (tensor): projection of $\mathsf{x}$ onto nonnegative orthant

        '''
        Px = torch.clamp(x, min=0)
        return Px

    def _project_C2(self, z):
        ''' Projection to the subspace of all $\mathsf{x}$ such that $\mathsf{Ax=b}$.

            Note:
                The singular value decomposition (SVD) representation
                of the matrix $\mathsf{A}$ is used to efficiently compute
                the projection.

            Args:
                z (tensor): point in Euclidean space

            Returns:
                Pz (tensor): projection onto the subspace $\mathsf{\{x : Ax = b\}}$

        '''
        res = self.A.matmul(z.permute(1, 0)) - self.b.view(-1, 1)
        temp = self.V.matmul(self.s_inv.view(-1, 1) * self.UT.matmul(res)).permute(1, 0)
        Pz = z - temp
        return Pz

    @abstractmethod
    def F(self, z, w):
        ''' Gradient of objective function. Must be defined for each problem type.

            Note:
                The parameters of $\mathsf{F}$ are stored in $\mathsf{w}$.

            Args:
                z (tensor): point in Euclidean space
                w (tensor): Parameters defining objective function.

            Returns:
                Fz (tensor): Gradient of objective function at $\mathsf{z}$
        '''
        pass

    @abstractmethod
    def _data_space_forward(self, d):
        ''' Specify the map from context d to the parameters of F. This is used during training.

            Args:
                d (tensor): Contextual data

            Returns:
                w (tensor): Parameters defining the objective function and its gradient F
        '''
        pass

    @abstractmethod
    def test_time_forward(self, d):
        ''' Specify test-time behaviour, e.g. use a combinatorial solver on the forward pass.

            Args:
                d (tensor): Contextual data

            Returns:
                x (tensor): inference/solution to parameterized optimization problem
        '''
        pass


    def _apply_DYS(self, z, w): 
        ''' Apply a single update step from Davis-Yin Splitting. 

            Args:
                z (tensor): Point in Euclidean space
                w (tensor): Parameters defining function and its gradient

            Returns:
                z (tensor): Updated estimate of solution
        '''
        x = self._project_C1(z)
        y = self._project_C2(2.0 * x - z - self.alpha * self.F(x, w))  # DYS evaluates the gradient at the projected point x
        z = z - x + y
        return z


    def _train_time_forward(self, d, eps=1.0e-2, max_depth=int(1e4), 
                depth_warning=True): 
        ''' Default forward behaviour during training.

            Args:
                d (tensor):           Contextual data
                eps (float):          Stopping criterion threshold
                max_depth (int):      Maximum number of DYS updates
                depth_warning (bool): Whether to print a warning message when max depth is reached

            Returns:
                z (tensor): P+O Inference
        '''
        with torch.no_grad():
            w = self._data_space_forward(d)
            self.depth = 0.0

            z = torch.rand((d.shape[0], self.n2), device=self.device)  # one iterate per sample in the batch

            all_samp_conv = False
            while not all_samp_conv and self.depth < max_depth:
                z_prev = z.clone()   
                z = self._apply_DYS(z, w)
                diff_norm = torch.max(torch.norm(z - z_prev, dim=1))  # per-sample update norm, then max over the batch
                self.depth += 1.0
                all_samp_conv = diff_norm <= eps

        if self.depth >= max_depth and depth_warning:
            print("\nWarning: Max Depth Reached - Break Forward Loop\n")
        if self.training:
            w = self._data_space_forward(d)
            z = self._apply_DYS(z.detach(), w)
            return self._project_C1(z)
        else:
            return self._project_C1(z).detach()


    def forward(self, d, eps=1.0e-2, max_depth=int(1e4),
                depth_warning=True):
        ''' Forward propagation of DYS-net.

            Note:
                A switch is included for using different behaviour at test/deployment. 

            Args:
                d (tensor):           Contextual data
                eps (float):          Stopping criterion threshold
                max_depth (int):      Maximum number of DYS updates
                depth_warning (bool): Whether to print a warning message when max depth is reached

            Returns:
                z (tensor): P+O Inference        
        '''
        if not self.training:
            return self.test_time_forward(d)

        return self._train_time_forward(d, eps, max_depth, depth_warning)
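
For orientation, here is a minimal subclass sketch (hypothetical, not part of the source): a quadratic objective \(f(\mathsf{z}; \mathsf{w}) = \tfrac{1}{2}\Vert \mathsf{z - w} \Vert^2\), so \(F(\mathsf{z}, \mathsf{w}) = \mathsf{z - w}\), with the parameters \(\mathsf{w}\) predicted from context by a single linear layer. The class name QuadraticDYS, the linear map, and the CPU device are illustrative assumptions.

import torch
import torch.nn as nn

from src.dys_opt_net import DYS_opt_net  # import path assumed from this page

class QuadraticDYS(DYS_opt_net):  # illustrative name, not from the library
    def __init__(self, A, b, context_dim, device='cpu', alpha=0.05):
        super().__init__(A, b, device=device, alpha=alpha)
        self.fc = nn.Linear(context_dim, self.n2, device=device)  # maps context d to parameters w

    def F(self, z, w):
        return z - w  # gradient of 0.5 * ||z - w||^2

    def _data_space_forward(self, d):
        return self.fc(d)

    def test_time_forward(self, d):
        # simplest choice: reuse the training-time DYS iteration at test time
        return self._train_time_forward(d, depth_warning=False)

A = torch.randn(3, 8)                    # linear system with 3 constraints, 8 variables
b = torch.randn(3)
net = QuadraticDYS(A, b, context_dim=5)
x = net(torch.randn(16, 5))              # batch of 16 contexts -> batch of 16 solutions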

F(z, w) abstractmethod

Gradient of objective function. Must be defined for each problem type.

Note

The parameters of \(\mathsf{F}\) are stored in \(\mathsf{w}\).

Parameters:

    Name  Type    Description                              Default
    z     tensor  point in Euclidean space                 required
    w     tensor  Parameters defining objective function.  required

Returns:

    Name  Type    Description
    Fz    tensor  Gradient of objective function at \(\mathsf{z}\)

Source code in src/dys_opt_net.py
@abstractmethod
def F(self, z, w):
    ''' Gradient of objective function. Must be defined for each problem type.

        Note:
            The parameters of $\mathsf{F}$ are stored in $\mathsf{w}$.

        Args:
            z (tensor): point in Euclidean space
            w (tensor): Parameters defining objective function.

        Returns:
            Fz (tensor): Gradient of objective function at $\mathsf{z}$
    '''
    pass
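
As an illustrative (hypothetical) implementation, for the linear predict-and-optimize cost \(f(\mathsf{z}; \mathsf{w}) = \mathsf{w^\top z}\) the gradient is constant in \(\mathsf{z}\):

def F(self, z, w):
    return w  # gradient of the linear cost w^T z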

forward(d, eps=1.0e-2, max_depth=int(1e4), depth_warning=True)

Forward propagation of DYS-net.

Note

A switch is included for using different behaviour at test/deployment.

Parameters:

    Name           Type    Description                                                   Default
    d              tensor  Contextual data                                               required
    eps            float   Stopping criterion threshold                                  1.0e-2
    max_depth      int     Maximum number of DYS updates                                 int(1e4)
    depth_warning  bool    Whether to print a warning message when max depth is reached  True

Returns:

    Name  Type    Description
    z     tensor  P+O Inference

Source code in src/dys_opt_net.py
def forward(self, d, eps=1.0e-2, max_depth=int(1e4),
            depth_warning=True):
    ''' Forward propagation of DYS-net.

        Note:
            A switch is included for using different behaviour at test/deployment. 

        Args:
            d (tensor):           Contextual data
            eps (float):          Stopping criterion threshold
            max_depth (int):      Maximum number of DYS updates
            depth_warning (bool): Whether to print a warning message when max depth is reached

        Returns:
            z (tensor): P+O Inference        
    '''
    if not self.training:
        return self.test_time_forward(d)

    return self._train_time_forward(d, eps, max_depth, depth_warning)
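
The switch in action, assuming net is an instantiated subclass and d a batch of contextual data (names are illustrative):

net.train()
x_train = net(d)  # runs the DYS fixed-point loop via _train_time_forward
net.eval()
x_test = net(d)   # dispatches to the subclass's test_time_forward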

test_time_forward(d) abstractmethod

Specify test-time behaviour, e.g. use a combinatorial solver on the forward pass.

Parameters:

    Name  Type    Description      Default
    d     tensor  Contextual data  required

Returns:

    Name  Type    Description
    x     tensor  inference/solution to parameterized optimization problem

Source code in src/dys_opt_net.py
@abstractmethod
def test_time_forward(self, d):
    ''' Specify test-time behaviour, e.g. use a combinatorial solver on the forward pass.

        Args:
            d (tensor): Contextual data

        Returns:
            x (tensor): inference/solution to parameterized optimization problem
    '''
    pass
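
A hypothetical override for a combinatorial task with binary solutions (the 0.5 threshold is an illustrative choice, not the library's):

def test_time_forward(self, d):
    x = self._train_time_forward(d, depth_warning=False)
    return (x > 0.5).float()  # round the relaxed DYS solution to a binary vector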