Package minimise :: Module steihaug_cg
[hide private]
[frames | no frames]

Source Code for Module minimise.steihaug_cg

  1  ############################################################################### 
  2  #                                                                             # 
  3  # Copyright (C) 2003 Edward d'Auvergne                                        # 
  4  #                                                                             # 
  5  # This file is part of the program relax.                                     # 
  6  #                                                                             # 
  7  # relax is free software; you can redistribute it and/or modify               # 
  8  # it under the terms of the GNU General Public License as published by        # 
  9  # the Free Software Foundation; either version 2 of the License, or           # 
 10  # (at your option) any later version.                                         # 
 11  #                                                                             # 
 12  # relax is distributed in the hope that it will be useful,                    # 
 13  # but WITHOUT ANY WARRANTY; without even the implied warranty of              # 
 14  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               # 
 15  # GNU General Public License for more details.                                # 
 16  #                                                                             # 
 17  # You should have received a copy of the GNU General Public License           # 
 18  # along with relax; if not, write to the Free Software                        # 
 19  # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA   # 
 20  #                                                                             # 
 21  ############################################################################### 
 22   
 23   
 24  from LinearAlgebra import inverse 
 25  from Numeric import Float64, dot, matrixmultiply, outerproduct, sqrt, zeros 
 26   
 27  from newton import Newton 
 28  from base_classes import Min, Trust_region 
 29   
 30   
def steihaug(func=None, dfunc=None, d2func=None, args=(), x0=None, func_tol=1e-25, grad_tol=None, maxiter=1e6, epsilon=1e-8, delta_max=1e5, delta0=1.0, eta=0.2, full_output=0, print_flag=0, print_prefix=""):
    """Steihaug conjugate-gradient trust region algorithm.

    Page 75 from 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.

    The CG-Steihaug algorithm is:

    epsilon > 0
    p0 = 0, r0 = g, d0 = -r0
    if ||r0|| < epsilon:
        return p = p0
    while 1:
        if djT.B.dj <= 0:
            Find tau such that p = pj + tau.dj minimises m(p) in (4.9) and satisfies ||p|| = delta
            return p
        aj = rjT.rj / djT.B.dj
        pj+1 = pj + aj.dj
        if ||pj+1|| >= delta:
            Find tau >= 0 such that p = pj + tau.dj satisfies ||p|| = delta
            return p
        rj+1 = rj + aj.B.dj
        if ||rj+1|| < epsilon.||r0||:
            return p = pj+1
        bj+1 = rj+1T.rj+1 / rjT.rj
        dj+1 = -rj+1 + bj+1.dj

    This function is a thin wrapper:  it optionally prints a header, builds a Steihaug object
    from the arguments, and returns whatever its minimise() method returns.

    Arguments:

    func:           The function to minimise, called as func(x, *args).
    dfunc:          The gradient function, called as dfunc(x, *args).
    d2func:         The second derivative (Hessian) function.  It is only stored by this module;
                    presumably it is called by the Newton base class (not visible here).
    args:           Tuple of extra arguments appended after x in each function call.
    x0:             The initial parameter vector.
    func_tol:       Function tolerance.  Presumably used by the convergence tests set up in the
                    base class (not visible here).
    grad_tol:       Gradient tolerance.  Presumably used by the same convergence tests.
    maxiter:        Maximum number of iterations.  Presumably enforced by the base class.
    epsilon:        The tolerance of the inner conjugate-gradient iterations (see algorithm above).
    delta_max:      Stored on the object; presumably the maximum trust region radius used by the
                    trust region update (not visible here).
    delta0:         The initial trust region radius delta.
    eta:            Stored on the object; presumably the step acceptance threshold of the trust
                    region update (not visible here).
    full_output:    Stored on the object; presumably selects how much minimise() returns.
    print_flag:     Verbosity.  Any true value prints the header, values >= 2 print an extra
                    blank prefixed line here and the details of every CG iteration.
    print_prefix:   String prepended to every printed line.
    """

    # Header.  Single-argument print(...) behaves identically as a Python 2 print statement.
    if print_flag:
        if print_flag >= 2:
            print(print_prefix)
        print(print_prefix)
        print(print_prefix + "CG-Steihaug minimisation")
        print(print_prefix + "~~~~~~~~~~~~~~~~~~~~~~~~")

    # Named 'minimiser' rather than 'min' so that the builtin min() is not shadowed.
    minimiser = Steihaug(func, dfunc, d2func, args, x0, func_tol, grad_tol, maxiter, epsilon, delta_max, delta0, eta, full_output, print_flag, print_prefix)
    return minimiser.minimise()
67 68
class Steihaug(Min, Trust_region, Newton):
    """CG-Steihaug trust region minimiser.

    The outer minimisation loop, the convergence tests, the trust region radius update and the
    Newton specific setup and update functions come from the Min, Trust_region and Newton base
    classes (not visible in this module).  This class supplies the inner conjugate-gradient
    solver of the trust region subproblem (get_pk and get_tau) together with the new_param_func
    and update hooks.
    """

    def __init__(self, func, dfunc, d2func, args, x0, func_tol, grad_tol, maxiter, epsilon, delta_max, delta0, eta, full_output, print_flag, print_prefix):
        """Class for Steihaug conjugate-gradient trust region minimisation specific functions.

        Unless you know what you are doing, you should call the function 'steihaug' rather than
        using this class.

        All arguments are stored unchanged as attributes of the same name, except that x0 is
        stored as self.xk (the current parameter vector) and delta0 as self.delta (the current
        trust region radius).
        """

        # Function arguments.
        self.func = func
        self.dfunc = dfunc
        self.d2func = d2func
        self.args = args
        self.xk = x0
        self.func_tol = func_tol
        self.grad_tol = grad_tol
        self.maxiter = maxiter
        self.full_output = full_output
        self.print_flag = print_flag
        self.print_prefix = print_prefix
        self.epsilon = epsilon
        self.delta_max = delta_max
        self.delta = delta0
        self.eta = eta

        # Initialise the function, gradient, and Hessian evaluation counters.
        self.f_count = 0
        self.g_count = 0
        self.h_count = 0

        # Initialise the warning string.
        self.warning = None

        # Set the convergence test function (inherited, not defined in this module).
        self.setup_conv_tests()

        # Newton setup function (inherited, not defined in this module).
        self.setup_newton()

        # Set the update function which update() runs once a step has been accepted.
        self.specific_update = self.update_newton


    def get_pk(self):
        """The CG-Steihaug algorithm.

        Approximately minimises the quadratic model inside the trust region of radius self.delta
        by conjugate-gradient iterations, starting from p0 = 0.  The gradient is read from
        self.dfk and the model Hessian B from self.d2fk; both are expected to have been set
        before this method is called (they are not assigned in this module).

        The step p is returned when one of the following occurs:

            - ||r0|| < epsilon:  the zero step is returned.
            - Non-positive curvature along dj:  the step is extended along dj to the boundary.
            - The next iterate would reach or leave the trust region:  as above.
            - ||rj+1|| < epsilon.||r0||:  the interior iterate pj+1 is returned.
        """

        verbose = self.print_flag >= 2

        # Initial values at j = 0.  The multiplications by 1.0 presumably serve to work on
        # copies rather than on the stored gradient and Hessian themselves.
        self.pj = zeros(len(self.xk), Float64)
        self.rj = self.dfk * 1.0
        self.dj = -self.dfk * 1.0
        self.B = self.d2fk * 1.0
        rj_sqrd = dot(self.rj, self.rj)
        len_r0 = sqrt(rj_sqrd)
        length_test = self.epsilon * len_r0

        if verbose:
            print(self.print_prefix + "p0: " + repr(self.pj))
            print(self.print_prefix + "r0: " + repr(self.rj))
            print(self.print_prefix + "d0: " + repr(self.dj))

        if len_r0 < self.epsilon:
            if verbose:
                print(self.print_prefix + "len rj < epsilon.")
            return self.pj

        # Iterate over j.
        j = 0
        while 1:
            # The curvature.  B.dj is kept as it is reused for the residual update below.
            B_dj = dot(self.B, self.dj)
            curv = dot(self.dj, B_dj)
            if verbose:
                print(self.print_prefix + "\nIteration j = " + repr(j))
                print(self.print_prefix + "Curv: " + repr(curv))

            # First test:  non-positive curvature, so follow dj to the trust region boundary.
            if curv <= 0.0:
                tau = self.get_tau()
                if verbose:
                    print(self.print_prefix + "curv <= 0.0, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            aj = rj_sqrd / curv
            self.pj_new = self.pj + aj * self.dj
            if verbose:
                print(self.print_prefix + "aj: " + repr(aj))
                print(self.print_prefix + "pj+1: " + repr(self.pj_new))

            # Second test:  the full CG step reaches or leaves the trust region, so stop on the
            # boundary instead.
            if sqrt(dot(self.pj_new, self.pj_new)) >= self.delta:
                tau = self.get_tau()
                if verbose:
                    print(self.print_prefix + "sqrt(dot(self.pj_new, self.pj_new)) >= self.delta, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            self.rj_new = self.rj + aj * B_dj
            rj_new_sqrd = dot(self.rj_new, self.rj_new)
            len_rj_new = sqrt(rj_new_sqrd)
            if verbose:
                print(self.print_prefix + "rj+1: " + repr(self.rj_new))

            # Third test:  the residual is small enough relative to the initial residual.
            if len_rj_new < length_test:
                if verbose:
                    print(self.print_prefix + "sqrt(dot(self.rj_new, self.rj_new)) < length_test")
                return self.pj_new

            bj_new = rj_new_sqrd / rj_sqrd
            self.dj_new = -self.rj_new + bj_new * self.dj
            if verbose:
                print(self.print_prefix + "len rj+1: " + repr(len_rj_new))
                print(self.print_prefix + "epsilon.||r0||: " + repr(length_test))
                print(self.print_prefix + "bj+1: " + repr(bj_new))
                print(self.print_prefix + "dj+1: " + repr(self.dj_new))

            # Update j+1 to j.
            self.pj = self.pj_new * 1.0
            self.rj = self.rj_new * 1.0
            self.dj = self.dj_new * 1.0
            rj_sqrd = rj_new_sqrd
            j = j + 1


    def get_tau(self):
        """Function to find tau such that p = pj + tau.dj, and ||p|| = delta.

        Expanding ||pj + tau.dj||^2 = delta^2 gives the quadratic

            ||dj||^2.tau^2 + 2.(pj.dj).tau + (||pj||^2 - delta^2) = 0,

        the larger root of which is

            tau = (-pj.dj + sqrt((pj.dj)^2 - ||dj||^2.(||pj||^2 - delta^2))) / ||dj||^2.

        The whole numerator is divided by ||dj||^2.  Dividing only the square root term is wrong
        whenever pj.dj is non-zero, i.e. on any CG iteration after the first.
        """

        dot_pj_dj = dot(self.pj, self.dj)
        len_dj_sqrd = dot(self.dj, self.dj)

        discriminant = dot_pj_dj**2 - len_dj_sqrd * (dot(self.pj, self.pj) - self.delta**2)
        tau = (-dot_pj_dj + sqrt(discriminant)) / len_dj_sqrd
        return tau


    def new_param_func(self):
        """Find the CG-Steihaug minimiser.

        Computes the step self.pk, the trial point self.xk_new = self.xk + self.pk, and the
        function value and gradient at the trial point (self.fk_new and self.dfk_new).  The
        function and gradient call counters are each incremented by one.
        """

        # Get the pk vector.
        self.pk = self.get_pk()

        # Find the new parameter vector and function value at that point.
        self.xk_new = self.xk + self.pk
        call_args = (self.xk_new,) + self.args
        self.fk_new = self.func(*call_args)
        self.f_count = self.f_count + 1
        self.dfk_new = self.dfunc(*call_args)
        self.g_count = self.g_count + 1


    def update(self):
        """Update function.

        Run the trust region update.  If this update decides to shift xk+1 to xk, then run the
        Newton update.
        """

        self.trust_region_update()
        if self.shift_flag:
            self.specific_update()
222