Package minfx :: Module steihaug_cg
[hide private]
[frames] | no frames]

Source Code for Module minfx.steihaug_cg

  1  ############################################################################### 
  2  #                                                                             # 
  3  # Copyright (C) 2003-2013 Edward d'Auvergne                                   # 
  4  #                                                                             # 
  5  # This file is part of the minfx optimisation library,                        # 
  6  # https://sourceforge.net/projects/minfx                                      # 
  7  #                                                                             # 
  8  # This program is free software: you can redistribute it and/or modify        # 
  9  # it under the terms of the GNU General Public License as published by        # 
 10  # the Free Software Foundation, either version 3 of the License, or           # 
 11  # (at your option) any later version.                                         # 
 12  #                                                                             # 
 13  # This program is distributed in the hope that it will be useful,             # 
 14  # but WITHOUT ANY WARRANTY; without even the implied warranty of              # 
 15  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               # 
 16  # GNU General Public License for more details.                                # 
 17  #                                                                             # 
 18  # You should have received a copy of the GNU General Public License           # 
 19  # along with this program.  If not, see <http://www.gnu.org/licenses/>.       # 
 20  #                                                                             # 
 21  ############################################################################### 
 22   
 23  # Module docstring. 
 24  """Steihaug conjugate-gradient trust region optimization. 
 25   
 26  This file is part of the minfx optimisation library at U{https://sourceforge.net/projects/minfx}. 
 27  """ 
 28   
 29  # Python module imports. 
 30  from numpy import dot, float64, sqrt, zeros 
 31   
 32  # Minfx module imports. 
 33  from minfx.base_classes import Min, Trust_region 
 34  from minfx.newton import Newton 
 35   
 36   
def steihaug(func=None, dfunc=None, d2func=None, args=(), x0=None, func_tol=1e-25, grad_tol=None, maxiter=1e6, epsilon=1e-8, delta_max=1e5, delta0=1.0, eta=0.2, full_output=0, print_flag=0, print_prefix=""):
    """Steihaug conjugate-gradient trust region algorithm.

    Page 75 from 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.  The CG-Steihaug algorithm for finding the trust region step p is:

        - epsilon > 0
        - p0 = 0, r0 = g, d0 = -r0
        - if ||r0|| < epsilon:
            - return p = p0
        - while 1:
            - if djT.B.dj <= 0:
                - Find tau such that p = pj + tau.dj minimises m(p) in (4.9) and satisfies ||p|| = delta
                - return p
            - aj = rjT.rj / djT.B.dj
            - pj+1 = pj + aj.dj
            - if ||pj+1|| >= delta:
                - Find tau such that p = pj + tau.dj satisfies ||p|| = delta
                - return p
            - rj+1 = rj + aj.B.dj
            - if ||rj+1|| < epsilon.||r0||:
                - return p = pj+1
            - bj+1 = rj+1T.rj+1 / rjT.rj
            - dj+1 = -rj+1 + bj+1.dj

    This function is a thin front end:  it optionally prints a banner, builds a Steihaug object from the arguments, and returns whatever its minimise() method returns.

    @keyword func:          The function to minimise, called as func(x, *args).
    @keyword dfunc:         The gradient function, called as dfunc(x, *args).
    @keyword d2func:        The Hessian function (handed to the Steihaug class unchanged).
    @keyword args:          The tuple of extra arguments appended to the parameter vector in each function call.
    @keyword x0:            The starting parameter vector.
    @keyword func_tol:      The function tolerance, passed on to the minimiser's convergence tests.
    @keyword grad_tol:      The gradient tolerance, passed on to the minimiser's convergence tests.
    @keyword maxiter:       The maximum number of iterations, passed on to the minimiser.
    @keyword epsilon:       The tolerance on the conjugate-gradient residual length.
    @keyword delta_max:     Passed on to the minimiser (presumably the upper bound on the trust region radius - handled outside this module).
    @keyword delta0:        The initial trust region radius.
    @keyword eta:           Passed on to the minimiser (presumably the step acceptance threshold of the trust region update - handled outside this module).
    @keyword full_output:   Passed on to the minimiser to select the form of the results.
    @keyword print_flag:    The verbosity level.  Any true value prints the banner, a value of 2 or more adds one extra prefix line.
    @keyword print_prefix:  The text placed in front of every printed line.
    @return:                The results of the minimise() method of the Steihaug object.
    """

    # The banner:  one prefix-only line (two when print_flag >= 2) followed by the underlined title.
    if print_flag:
        leading_lines = 2 if print_flag >= 2 else 1
        for _ in range(leading_lines):
            print(print_prefix)
        for text in ("CG-Steihaug minimisation", "~~~~~~~~~~~~~~~~~~~~~~~~"):
            print(print_prefix + text)

    # Set up the minimiser and run it.
    minimiser = Steihaug(
        func, dfunc, d2func, args, x0,
        func_tol, grad_tol, maxiter,
        epsilon, delta_max, delta0, eta,
        full_output, print_flag, print_prefix,
    )
    return minimiser.minimise()
71 72
class Steihaug(Trust_region, Newton):
    """CG-Steihaug trust region minimiser.

    The class supplies the trust region subproblem solver (get_pk() and get_tau()) together with the parameter and state update hooks (new_param_func() and update()).  The methods setup_conv_tests(), setup_newton(), update_newton() and trust_region_update(), as well as the attributes dfk, d2fk and shift_flag, are not defined here and are expected to come from the base classes.
    """

    def __init__(self, func, dfunc, d2func, args, x0, func_tol, grad_tol, maxiter, epsilon, delta_max, delta0, eta, full_output, print_flag, print_prefix):
        """Class for Steihaug conjugate-gradient trust region minimisation specific functions.

        Unless you know what you are doing, you should call the function 'steihaug' rather than using this class.

        @param func:            The function to minimise, called as func(x, *args).
        @param dfunc:           The gradient function, called as dfunc(x, *args).
        @param d2func:          The Hessian function (only stored here).
        @param args:            The extra arguments appended to the parameter vector in each function call.
        @param x0:              The starting parameter vector, stored as xk.
        @param func_tol:        The function tolerance (only stored here).
        @param grad_tol:        The gradient tolerance (only stored here).
        @param maxiter:         The maximum number of iterations (only stored here).
        @param epsilon:         The tolerance on the conjugate-gradient residual length used by get_pk().
        @param delta_max:       Stored as delta_max (not used within this class).
        @param delta0:          The initial trust region radius, stored as delta.
        @param eta:             Stored as eta (not used within this class).
        @param full_output:     Stored as full_output (not used within this class).
        @param print_flag:      The verbosity level.  Values of 2 or more make get_pk() print its intermediate quantities.
        @param print_prefix:    The text placed in front of every printed line.
        """

        # Function arguments.
        self.func = func
        self.dfunc = dfunc
        self.d2func = d2func
        self.args = args
        self.xk = x0
        self.func_tol = func_tol
        self.grad_tol = grad_tol
        self.maxiter = maxiter
        self.full_output = full_output
        self.print_flag = print_flag
        self.print_prefix = print_prefix
        self.epsilon = epsilon
        self.delta_max = delta_max
        self.delta = delta0
        self.eta = eta

        # Initialise the function, gradient, and Hessian evaluation counters.
        self.f_count = 0
        self.g_count = 0
        self.h_count = 0

        # Initialise the warning string.
        self.warning = None

        # Set the convergence test function.
        self.setup_conv_tests()

        # Newton setup function.
        self.setup_newton()

        # Set the update function.
        self.specific_update = self.update_newton


    def get_pk(self):
        """The CG-Steihaug algorithm.

        Conjugate-gradient iterations are run on the quadratic model built from the current gradient (dfk) and Hessian (d2fk).  The iterations stop as soon as one of the following holds:

            - the initial residual is shorter than epsilon (the zero step is returned),
            - a direction of non-positive curvature is found (the step is extended to the trust region boundary),
            - the next iterate would reach or leave the trust region (the step is cut at the boundary),
            - the residual has dropped below epsilon.||r0|| (the current iterate is returned).

        @return:    The step pk.
        """

        # Initial values at j = 0 (the multiplication by 1.0 creates new floating point arrays).
        self.pj = zeros(len(self.xk), float64)
        self.rj = self.dfk * 1.0
        self.dj = -self.dfk * 1.0
        self.B = self.d2fk * 1.0
        len_r0 = sqrt(dot(self.rj, self.rj))
        length_test = self.epsilon * len_r0

        # The debugging printout switch.
        verbose = self.print_flag >= 2

        if verbose:
            print(self.print_prefix + "p0: " + repr(self.pj))
            print(self.print_prefix + "r0: " + repr(self.rj))
            print(self.print_prefix + "d0: " + repr(self.dj))

        # The gradient is already (almost) zero, so no step is taken.
        if len_r0 < self.epsilon:
            if verbose:
                print(self.print_prefix + "len rj < epsilon.")
            return self.pj

        # Iterate over j.
        j = 0
        while True:
            # The curvature along dj (B.dj is reused below for the residual update).
            B_dj = dot(self.B, self.dj)
            curv = dot(self.dj, B_dj)
            if verbose:
                print(self.print_prefix + "\nIteration j = " + repr(j))
                print(self.print_prefix + "Curv: " + repr(curv))

            # First test:  non-positive curvature, so follow dj to the trust region boundary.
            if curv <= 0.0:
                tau = self.get_tau()
                if verbose:
                    print(self.print_prefix + "curv <= 0.0, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            # The step length along dj and the trial iterate.
            rj_sqrd = dot(self.rj, self.rj)
            aj = rj_sqrd / curv
            self.pj_new = self.pj + aj * self.dj
            if verbose:
                print(self.print_prefix + "aj: " + repr(aj))
                print(self.print_prefix + "pj+1: " + repr(self.pj_new))

            # Second test:  the trial iterate is on or outside the boundary, so cut the step there.
            if sqrt(dot(self.pj_new, self.pj_new)) >= self.delta:
                tau = self.get_tau()
                if verbose:
                    print(self.print_prefix + "sqrt(dot(self.pj_new, self.pj_new)) >= self.delta, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            # The new residual.
            self.rj_new = self.rj + aj * B_dj
            rj_new_sqrd = dot(self.rj_new, self.rj_new)
            len_rj_new = sqrt(rj_new_sqrd)
            if verbose:
                print(self.print_prefix + "rj+1: " + repr(self.rj_new))

            # Third test:  the residual is small enough relative to the initial residual.
            if len_rj_new < length_test:
                if verbose:
                    print(self.print_prefix + "sqrt(dot(self.rj_new, self.rj_new)) < length_test")
                return self.pj_new

            # The new conjugate direction.
            bj_new = rj_new_sqrd / rj_sqrd
            self.dj_new = -self.rj_new + bj_new * self.dj
            if verbose:
                print(self.print_prefix + "len rj+1: " + repr(len_rj_new))
                print(self.print_prefix + "epsilon.||r0||: " + repr(length_test))
                print(self.print_prefix + "bj+1: " + repr(bj_new))
                print(self.print_prefix + "dj+1: " + repr(self.dj_new))

            # Update j+1 to j.
            self.pj = self.pj_new * 1.0
            self.rj = self.rj_new * 1.0
            self.dj = self.dj_new * 1.0
            j = j + 1


    def get_tau(self):
        """Function to find tau such that p = pj + tau.dj, and ||p|| = delta.

        Expanding ||pj + tau.dj||^2 = delta^2 gives the quadratic

            ||dj||^2.tau^2 + 2.(pj.dj).tau + ||pj||^2 - delta^2 = 0,

        the positive root of which is

            tau = (-pj.dj + sqrt((pj.dj)^2 - ||dj||^2.(||pj||^2 - delta^2))) / ||dj||^2.

        @return:    The positive root tau.
        """

        dot_pj_dj = dot(self.pj, self.dj)
        len_dj_sqrd = dot(self.dj, self.dj)
        discriminant = dot_pj_dj**2 - len_dj_sqrd * (dot(self.pj, self.pj) - self.delta**2)

        # The whole numerator must be divided by ||dj||^2.  Dividing only the square root gives a point off the boundary whenever pj.dj is non-zero.
        tau = (-dot_pj_dj + sqrt(discriminant)) / len_dj_sqrd
        return tau


    def new_param_func(self):
        """Find the CG-Steihaug minimiser.

        The step pk is calculated, then the function and the gradient are evaluated at the trial point xk + pk.  The results are stored in xk_new, fk_new and dfk_new, and the evaluation counters f_count and g_count are each incremented.
        """

        # Get the pk vector.
        self.pk = self.get_pk()

        # Find the new parameter vector and function value at that point.
        self.xk_new = self.xk + self.pk
        self.fk_new = self.func(self.xk_new, *self.args)
        self.f_count = self.f_count + 1
        self.dfk_new = self.dfunc(self.xk_new, *self.args)
        self.g_count = self.g_count + 1


    def update(self):
        """Update function.

        Run the trust region update.  If this update decides to shift xk+1 to xk, then run the Newton update.
        """

        self.trust_region_update()
        if self.shift_flag:
            self.specific_update()
224