DYS-Net

Bases: Module, ABC

Abstract implementation of a Davis-Yin Splitting (DYS) layer in a neural network.

Note

The singular value decomposition of the matrix \(\mathsf{A}\) is used for the projection onto the subspace of all \(\mathsf{x}\) such that \(\mathsf{Ax=b}\).

Parameters:

| Name | Type | Description | Default |
| ------ | ------ | ------------------------------------------ | -------- |
| A | tensor | Matrix for linear system | required |
| b | tensor | Measurement vector for linear system | required |
| device | string | Device on which to perform computations | 'mps' |
| alpha | float | Step size for DYS updates | 0.05 |

Map Context Data to Gradient Parameters

Specify the map from context d to parameters of F.

Source code in src/dys_opt_net.py
@abstractmethod
def data_space_forward(self, d):
    ''' Specify the map from context d to parameters of F.
    '''
    pass
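Since `data_space_forward` is abstract, each subclass supplies its own map. As a hypothetical sketch (the class name `ContextMap`, the layer sizes, and the shapes below are illustrative assumptions, not part of the source), a subclass might implement it with a small fully connected network:

```python
import torch
import torch.nn as nn

# Illustrative sketch only: names and dimensions are assumptions.
class ContextMap(nn.Module):
    def __init__(self, context_dim, param_dim, hidden_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_dim),
        )

    def data_space_forward(self, d):
        # Map context d to the parameters w that define F
        return self.net(d)

model = ContextMap(context_dim=5, param_dim=10)
w = model.data_space_forward(torch.randn(4, 5))  # w has shape (4, 10)
```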


Project onto \(C_1\)

Projection to the non-negative orthant.

Parameters:

| Name | Type | Description | Default |
| ---- | ------ | ------------------------ | -------- |
| x | tensor | point in Euclidean space | required |

Returns:

| Name | Type | Description |
| ---- | ------ | --------------------------------------------------------- |
| Px | tensor | projection of \(\mathsf{x}\) onto the nonnegative orthant |

Source code in src/dys_opt_net.py
def project_C1(self, x):
    ''' Projection to the non-negative orthant.

    Args:
        x (tensor): point in Euclidean space

    Returns:
        Px (tensor): projection of $\mathsf{x}$ onto nonnegative orthant

    '''
    Px = torch.clamp(x, min=0)
    return Px
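`torch.clamp` with `min=0` zeroes out the negative entries elementwise, which is exactly the Euclidean projection onto the nonnegative orthant. A quick check:

```python
import torch

x = torch.tensor([-1.5, 0.0, 2.0])
Px = torch.clamp(x, min=0)  # -> tensor([0., 0., 2.])
```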


Project onto \(C_2\)

Projection to the subspace of all \(\mathsf{x}\) such that \(\mathsf{Ax=b}\).

Note

The singular value decomposition (SVD) representation of the matrix \(\mathsf{A}\) is used to efficiently compute the projection.

Parameters:

| Name | Type | Description | Default |
| ---- | ------ | ------------------------ | -------- |
| z | tensor | point in Euclidean space | required |

Returns:

| Name | Type | Description |
| ---- | ------ | ------------------------------------------------------------ |
| Pz | tensor | projection onto the subspace \(\mathsf{\{x : Ax = b\}}\) |

Source code in src/dys_opt_net.py
def project_C2(self, z):
    ''' Projection to the subspace of all $\mathsf{x}$ such that $\mathsf{Ax=b}$.

        Note:
            The singular value decomposition (SVD) representation
            of the matrix $\mathsf{A}$ is used to efficiently compute
            the projection.

        Args:
            z (tensor): point in Euclidean space

        Returns:
            Pz (tensor): projection onto subspace $\mathsf{\{x : Ax = b\}}$

    '''
    res = self.A.matmul(z.permute(1, 0)) - self.b.view(-1, 1)
    temp = self.V.matmul(self.s_inv.view(-1, 1) * self.UT.matmul(res)).permute(1, 0)
    Pz = z - temp
    return Pz
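The computation relies on the identity that, for a full-row-rank \(\mathsf{A}\) with thin SVD \(\mathsf{A = U \Sigma V^\top}\), the pseudo-inverse is \(\mathsf{A^+ = V \Sigma^{-1} U^\top}\) and the projection is \(\mathsf{z - A^+(Az - b)}\). A standalone sketch (matrix sizes and seed are illustrative assumptions) that mirrors the batched, permute-based computation above and verifies that every projected point satisfies the constraint:

```python
import torch

torch.manual_seed(0)
m, n = 3, 6
A = torch.randn(m, n)   # assumed full row rank (generic random matrix)
b = torch.randn(m)

# Thin SVD A = U diag(s) Vh; pseudo-inverse is Vh.T diag(1/s) U.T
U, s, Vh = torch.linalg.svd(A, full_matrices=False)
V, UT, s_inv = Vh.t(), U.t(), 1.0 / s

def project_C2(z):
    # z has shape (batch, n), mirroring the permute-based batching above
    res = A.matmul(z.permute(1, 0)) - b.view(-1, 1)
    temp = V.matmul(s_inv.view(-1, 1) * UT.matmul(res)).permute(1, 0)
    return z - temp

z = torch.randn(2, n)
Pz = project_C2(z)
# Each row of Pz satisfies A @ Pz[i] == b up to floating-point error
```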


Gradient Operation

Gradient of objective function. Must be defined for each problem type.

Note

The parameters of \(\mathsf{F}\) are stored in \(\mathsf{w}\).

Parameters:

| Name | Type | Description | Default |
| ---- | ------ | --------------------------------------------- | -------- |
| z | tensor | point in Euclidean space | required |
| w | tensor | Parameters defining function and its gradient | required |
Source code in src/dys_opt_net.py
@abstractmethod
def F(self, z, w):
    ''' Gradient of objective function. Must be defined for each problem type.

        Note:
            The parameters of $\mathsf{F}$ are stored in $\mathsf{w}$.

        Args:
            z (tensor): point in Euclidean space
            w (tensor): Parameters defining function and its gradient
    '''
    pass
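For intuition, here are two hypothetical instances of `F` (both are illustrative assumptions, not definitions from the source): a linear objective \(\mathsf{w^\top z}\), whose gradient is the constant \(\mathsf{w}\), and a Tikhonov-regularized variant:

```python
import torch

# Hypothetical examples of F for two simple objectives (illustrative only).

def F_linear(z, w):
    # Objective w^T z  ->  gradient is simply w
    return w

def F_tikhonov(z, w, gamma=0.1):
    # Objective w^T z + (gamma/2) ||z||^2  ->  gradient w + gamma * z
    return w + gamma * z

z, w = torch.zeros(3), torch.ones(3)
g = F_tikhonov(z, w)  # equals w when z = 0
```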


Apply Optimization Layer

Apply a single update step from Davis-Yin Splitting.

Parameters:

| Name | Type | Description | Default |
| ---- | ------ | --------------------------------------------- | -------- |
| z | tensor | Point in Euclidean space | required |
| w | tensor | Parameters defining function and its gradient | required |

Returns:

| Name | Type | Description |
| ---- | ------ | ---------------------------- |
| z | tensor | Updated estimate of solution |

Source code in src/dys_opt_net.py
def apply_DYS(self, z, w): 
    ''' Apply a single update step from Davis-Yin Splitting. 

        Args:
            z (tensor): Point in Euclidean space
            w (tensor): Parameters defining function and its gradient

        Returns:
            z (tensor): Updated estimate of solution
    '''
    x = self.project_C1(z)
    y = self.project_C2(2.0 * x - z - self.alpha*self.F(z, w))
    z = z - x + y
    return z
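Putting the pieces together on a toy problem (everything below, the constraint, cost vector, and iteration count, is an illustrative assumption): with a linear cost, so that \(\mathsf{F(z, w) = w}\), repeating the DYS update solves \(\min \mathsf{w^\top x}\) over \(\{\mathsf{x \ge 0,\ \sum_i x_i = 1}\}\), whose solution puts all mass on the cheapest coordinate:

```python
import torch

torch.manual_seed(0)
n = 4
A = torch.ones(1, n)                 # constraint: sum(x) = 1
b = torch.tensor([1.0])
U, s, Vh = torch.linalg.svd(A, full_matrices=False)

def P1(x):                           # projection onto the nonnegative orthant
    return torch.clamp(x, min=0)

def P2(z):                           # projection onto {x : Ax = b} via the SVD
    res = A @ z.view(-1, 1) - b.view(-1, 1)
    return z - (Vh.t() @ ((1.0 / s).view(-1, 1) * (U.t() @ res))).view(-1)

alpha = 0.05
w = torch.tensor([0.4, 0.3, 0.2, 0.1])   # linear cost; F(z, w) = w

z = torch.rand(n)
for _ in range(5000):                # one DYS update per pass, as in apply_DYS
    x = P1(z)
    y = P2(2.0 * x - z - alpha * w)
    z = z - x + y
x = P1(z)
# x concentrates on the cheapest coordinate, approximately [0, 0, 0, 1]
```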


Train Time Forward

Default forward behaviour during training.

Parameters:

| Name | Type | Description | Default |
| ------------- | ------ | ------------------------------------------------------------ | ---------- |
| d | tensor | Contextual data | required |
| eps | float | Stopping criterion threshold | 1.0e-2 |
| max_depth | int | Maximum number of DYS updates | int(1e4) |
| depth_warning | bool | Whether to print a warning message when max depth is reached | True |

Returns:

| Name | Type | Description |
| ---- | ------ | ----------------------------------------- |
| z | tensor | P+O (predict-and-optimize) inference |

Source code in src/dys_opt_net.py
def train_time_forward(self, d, eps=1.0e-2, max_depth=int(1e4), 
            depth_warning=True): 
    ''' Default forward behaviour during training.

        Args:
            d (tensor):           Contextual data
            eps (float):          Stopping criterion threshold
            max_depth (int):      Maximum number of DYS updates
            depth_warning (bool): Boolean for whether to print warning message when max depth reached

        Returns:
            z (tensor): P+O Inference
    '''
    with torch.no_grad():
        w = self.data_space_forward(d)
        self.depth = 0.0

        z = torch.rand((self.n2), device=self.device)
        z_prev = z.clone()      

        all_samp_conv = False
        while not all_samp_conv and self.depth < max_depth:
            z_prev = z.clone()   
            z = self.apply_DYS(z, w)
            diff_norm = torch.norm(z - z_prev)  # size of the update; stop when small
            self.depth += 1.0
            all_samp_conv = diff_norm <= eps

    if self.depth >= max_depth and depth_warning:
        print("\nWarning: Max Depth Reached - Break Forward Loop\n")
    if self.training:
        w = self.data_space_forward(d)
        z = self.apply_DYS(z.detach(), w)
        return self.project_C1(z)
    else:
        return self.project_C1(z).detach()
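Note the pattern in the method above: the fixed-point loop runs under `torch.no_grad()`, and only the final `apply_DYS` step (applied to `z.detach()`) is tracked, so backpropagation differentiates through a single update rather than the whole unrolled loop, in the style of Jacobian-free backprop. A minimal toy sketch of that pattern (the contraction map below is an assumption for illustration):

```python
import torch

w = torch.tensor([2.0], requires_grad=True)

def step(z, w):
    return 0.5 * z + w           # toy contraction with fixed point z* = 2w

# 1) Iterate to the fixed point without tracking gradients
with torch.no_grad():
    z = torch.zeros(1)
    for _ in range(60):
        z = step(z, w)

# 2) One tracked step through the final update only
z = step(z.detach(), w)
z.sum().backward()
# w.grad is d(step)/dw = 1 for the last step, not the full unrolled Jacobian
```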


Test Time Forward

Forward propagation of DYS-net.

Note

A switch is included for using different behaviour at test/deployment.

Parameters:

| Name | Type | Description | Default |
| ------------- | ------ | ------------------------------------------------------------ | ---------- |
| d | tensor | Contextual data | required |
| eps | float | Stopping criterion threshold | 1.0e-2 |
| max_depth | int | Maximum number of DYS updates | int(1e4) |
| depth_warning | bool | Whether to print a warning message when max depth is reached | True |

Returns:

| Name | Type | Description |
| ---- | ------ | --------------- |
| z | tensor | P+O inference |

Source code in src/dys_opt_net.py
def forward(self, d, eps=1.0e-2, max_depth=int(1e4), 
            depth_warning=True):
    ''' Forward propagation of DYS-net.

        Note:
            A switch is included for using different behaviour at test/deployment. 

        Args:
            d (tensor):           Contextual data
            eps (float):          Stopping criterion threshold
            max_depth (int):      Maximum number of DYS updates
            depth_warning (bool): Boolean for whether to print warning message when max depth reached

        Returns:
            z (tensor): P+O Inference        
    '''
    if not self.training:
      return self.test_time_forward(d)

    return self.train_time_forward(d, eps, max_depth, depth_warning)