1   
  2   
  3   
  4   
  5   
  6   
  7   
  8   
  9   
 10   
 11   
 12   
 13   
 14   
 15   
 16   
 17   
 18   
 19   
 20   
 21   
 22   
 23   
 24  """Steihaug conjugate-gradient trust region optimization. 
 25   
 26  This file is part of the minfx optimisation library at U{https://sourceforge.net/projects/minfx}. 
 27  """ 
 28   
 29   
 30  from numpy import dot, float64, sqrt, zeros 
 31   
 32   
 33  from minfx.base_classes import Min, Trust_region 
 34  from minfx.newton import Newton 
 35   
 36   
 37 -def steihaug(func=None, dfunc=None, d2func=None, args=(), x0=None, func_tol=1e-25, grad_tol=None, maxiter=1e6, epsilon=1e-8, delta_max=1e5, delta0=1.0, eta=0.2, full_output=0, print_flag=0, print_prefix=""): 
  38      """Steihaug conjugate-gradient trust region algorithm. 
 39   
 40      Page 75 from 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.  The CG-Steihaug algorithm is: 
 41   
 42          - epsilon > 0 
 43          - p0 = 0, r0 = g, d0 = -r0 
 44          - if ||r0|| < epsilon: 
 45              - return p = p0 
 46          - while 1: 
 47              - if djT.B.dj <= 0: 
 48                  - Find tau such that p = pj + tau.dj minimises m(p) in (4.9) and satisfies ||p|| = delta 
 49                  - return p 
 50              - aj = rjT.rj / djT.B.dj 
 51              - pj+1 = pj + aj.dj 
 52              - if ||pj+1|| >= delta: 
 53                  - Find tau such that p = pj + tau.dj satisfies ||p|| = delta 
 54                  - return p 
 55              - rj+1 = rj + aj.B.dj 
 56              - if ||rj+1|| < epsilon.||r0||: 
 57                  - return p = pj+1 
 58              - bj+1 = rj+1T.rj+1 / rjT.rj 
 59              - dj+1 = rj+1 + bj+1.dj 
 60      """ 
 61   
 62      if print_flag: 
 63          if print_flag >= 2: 
 64              print(print_prefix) 
 65          print(print_prefix) 
 66          print(print_prefix + "CG-Steihaug minimisation") 
 67          print(print_prefix + "~~~~~~~~~~~~~~~~~~~~~~~~") 
 68      min = Steihaug(func, dfunc, d2func, args, x0, func_tol, grad_tol, maxiter, epsilon, delta_max, delta0, eta, full_output, print_flag, print_prefix) 
 69      results = min.minimise() 
 70      return results 
  71   
 72   
 74 -    def __init__(self, func, dfunc, d2func, args, x0, func_tol, grad_tol, maxiter, epsilon, delta_max, delta0, eta, full_output, print_flag, print_prefix): 
  75          """Class for Steihaug conjugate-gradient trust region minimisation specific functions. 
 76   
 77          Unless you know what you are doing, you should call the function 'steihaug' rather than using this class. 
 78          """ 
 79   
 80           
 81          self.func = func 
 82          self.dfunc = dfunc 
 83          self.d2func = d2func 
 84          self.args = args 
 85          self.xk = x0 
 86          self.func_tol = func_tol 
 87          self.grad_tol = grad_tol 
 88          self.maxiter = maxiter 
 89          self.full_output = full_output 
 90          self.print_flag = print_flag 
 91          self.print_prefix = print_prefix 
 92          self.epsilon = epsilon 
 93          self.delta_max = delta_max 
 94          self.delta = delta0 
 95          self.eta = eta 
 96   
 97           
 98          self.f_count = 0 
 99          self.g_count = 0 
100          self.h_count = 0 
101   
102           
103          self.warning = None 
104   
105           
106          self.setup_conv_tests() 
107   
108           
109          self.setup_newton() 
110   
111           
112          self.specific_update = self.update_newton 
 113   
114   
        """The CG-Steihaug algorithm.

        Approximately solve the trust region subproblem with conjugate gradient iterations.  The iterations use the current gradient self.dfk (as r0) and the matrix self.d2fk (as B).  They stop on the first of three conditions:
        - non-positive curvature along dj;
        - the next iterate reaching the trust region radius self.delta;
        - the residual norm dropping below self.epsilon times its initial value.

        Side effects: the CG state is left in self.pj, self.rj, self.dj and self.B, together with the *_new attributes.  self.get_tau() reads self.pj, self.dj and self.delta.

        @return:    The step p.
        """

        # Initial CG state: p0 = 0, r0 = g, d0 = -r0.  The '* 1.0' yields new objects rather than aliases of
        # self.dfk / self.d2fk (assumes these are numpy arrays - TODO confirm).
        self.pj = zeros(len(self.xk), float64)
        self.rj = self.dfk * 1.0
        self.dj = -self.dfk * 1.0
        self.B = self.d2fk * 1.0
        len_r0 = sqrt(dot(self.rj, self.rj))
        # Relative convergence threshold epsilon.||r0|| for the residual test inside the loop.
        length_test = self.epsilon * len_r0

        if self.print_flag >= 2:
            print(self.print_prefix + "p0: " + repr(self.pj))
            print(self.print_prefix + "r0: " + repr(self.rj))
            print(self.print_prefix + "d0: " + repr(self.dj))

        # Gradient already (absolutely) small: return the zero step.
        if len_r0 < self.epsilon:
            if self.print_flag >= 2:
                print(self.print_prefix + "len rj < epsilon.")
            return self.pj

        # CG iterations.  NOTE(review): there is no iteration cap; termination relies on one of the three returns below.
        j = 0
        while True:
            # Curvature of the model along the search direction: djT.B.dj.
            curv = dot(self.dj, dot(self.B, self.dj))
            if self.print_flag >= 2:
                print(self.print_prefix + "\nIteration j = " + repr(j))
                print(self.print_prefix + "Curv: " + repr(curv))

            # Non-positive curvature: follow dj to the trust region boundary (get_tau uses the current self.pj, self.dj).
            if curv <= 0.0:
                tau = self.get_tau()
                if self.print_flag >= 2:
                    print(self.print_prefix + "curv <= 0.0, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            # Step length and tentative next iterate.
            aj = dot(self.rj, self.rj) / curv
            self.pj_new = self.pj + aj * self.dj
            if self.print_flag >= 2:
                print(self.print_prefix + "aj: " + repr(aj))
                print(self.print_prefix + "pj+1: " + repr(self.pj_new))

            # Tentative iterate leaves the trust region: instead stop on the boundary along dj from the old pj.
            if sqrt(dot(self.pj_new, self.pj_new)) >= self.delta:
                tau = self.get_tau()
                if self.print_flag >= 2:
                    print(self.print_prefix + "sqrt(dot(self.pj_new, self.pj_new)) >= self.delta, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            # Residual update: rj+1 = rj + aj.B.dj.
            self.rj_new = self.rj + aj * dot(self.B, self.dj)
            if self.print_flag >= 2:
                print(self.print_prefix + "rj+1: " + repr(self.rj_new))

            # Converged: residual reduced by the relative factor epsilon.
            if sqrt(dot(self.rj_new, self.rj_new)) < length_test:
                if self.print_flag >= 2:
                    print(self.print_prefix + "sqrt(dot(self.rj_new, self.rj_new)) < length_test")
                return self.pj_new

            # Fletcher-Reeves style beta and the new direction dj+1 = -rj+1 + bj+1.dj.
            bj_new = dot(self.rj_new, self.rj_new) / dot(self.rj, self.rj)
            self.dj_new = -self.rj_new + bj_new * self.dj
            if self.print_flag >= 2:
                print(self.print_prefix + "len rj+1: " + repr(sqrt(dot(self.rj_new, self.rj_new))))
                print(self.print_prefix + "epsilon.||r0||: " + repr(length_test))
                print(self.print_prefix + "bj+1: " + repr(bj_new))
                print(self.print_prefix + "dj+1: " + repr(self.dj_new))

            # Advance the CG state (as copies) for the next iteration.
            self.pj = self.pj_new * 1.0
            self.rj = self.rj_new * 1.0
            self.dj = self.dj_new * 1.0

            # Iteration counter (only used in the trace output).
            j = j + 1
 191   
192   
194          """Function to find tau such that p = pj + tau.dj, and ||p|| = delta.""" 
195   
196          dot_pj_dj = dot(self.pj, self.dj) 
197          len_dj_sqrd = dot(self.dj, self.dj) 
198   
199          tau = -dot_pj_dj + sqrt(dot_pj_dj**2 - len_dj_sqrd * (dot(self.pj, self.pj) - self.delta**2)) / len_dj_sqrd 
200          return tau 
 201   
202   
204          """Find the CG-Steihaug minimiser.""" 
205   
206           
207          self.pk = self.get_pk() 
208   
209           
210          self.xk_new = self.xk + self.pk 
211          self.fk_new, self.f_count = self.func(*(self.xk_new,)+self.args), self.f_count + 1 
212          self.dfk_new, self.g_count = self.dfunc(*(self.xk_new,)+self.args), self.g_count + 1 
 213   
214   
216          """Update function. 
217   
218          Run the trust region update.  If this update decides to shift xk+1 to xk, then run the Newton update. 
219          """ 
220   
221          self.trust_region_update() 
222          if self.shift_flag: 
223              self.specific_update()