Package minimise :: Package line_search :: Module more_thuente
[hide private]
[frames] | [no frames]

Source Code for Module minimise.line_search.more_thuente

  1  import sys 
  2  from copy import deepcopy 
  3  from math import sqrt 
  4  from Numeric import copy, dot 
  5   
  6  from interpolate import cubic_int, cubic_ext, quadratic_fafbga, quadratic_gagb 
  7   
  # Short names for the interpolation routines that update() calls:
  #   cubic(al, at, fl, ft, gl, gt)   - minimiser of the interpolating cubic,
  #   quadratic(al, at, fl, ft, gl)   - minimiser of the quadratic through fl, ft and gl,
  #   secant(al, at, gl, gt)          - the secant step computed from the two gradients.
cubic, quadratic, secant = cubic_int, quadratic_fafbga, quadratic_gagb
 11   
 12   
def more_thuente(func, func_prime, args, x, f, g, p, a_init=1.0, a_min=None, a_max=None, a_tol=1e-10, phi_min=-1e3, mu=0.001, eta=0.1, print_flag=0):
    """A line search algorithm from More and Thuente.

    More, J. J., and Thuente, D. J. 1994, Line search algorithms with guaranteed sufficient decrease.
    ACM Trans. Math. Softw. 20, 286-307.


    Arguments
    ~~~~~~~~~

    func:  The function to minimise, called as func(x + alpha*p, *args).
    func_prime:  The gradient of func, called as func_prime(x + alpha*p, *args).
    args:  Tuple of extra arguments passed to func and func_prime.
    x:  The current point.
    f:  The function value at x, i.e. phi(0).
    g:  The gradient at x.
    p:  The search direction.  dot(g, p) must not be positive.
    a_init:  The initial step length.
    a_min:  The smallest allowed step length.  A false value (None or 0) means 0.0.
    a_max:  The largest allowed step length.  A false value means 4*max(1, a_init).
    a_tol:  Relative tolerance on the width of the bracketing interval.
    phi_min:  Not used by this implementation.
    mu:  The sufficient decrease constant.
    eta:  The curvature constant of the strong Wolfe conditions.
    print_flag:  If true, print a trace of every iteration.


    Returns
    ~~~~~~~

    The tuple (alpha, f_count, g_count), where alpha is the accepted step length and f_count and g_count
    are the numbers of function and gradient evaluations performed.  The search stops when the strong
    Wolfe conditions hold, when alpha sits at a_min or a_max under the MINPACK warning conditions, or,
    once the minimiser is bracketed, when the interval can no longer be usefully reduced.


    Raises
    ~~~~~~

    NameError if dot(g, p) > 0 (p is not a descent direction), or if a_init lies outside [a_min, a_max].


    Internal variables
    ~~~~~~~~~~~~~~~~~~

    a0, the null sequence data structure containing the following keys:
        'a'         - 0
        'phi'       - phi(0)
        'phi_prime' - phi'(0)

    a, the sequence data structure containing the following keys:
        'a'         - alpha
        'phi'       - phi(alpha)
        'phi_prime' - phi'(alpha)

    Ik, the interval data structure containing the following keys:
        'a'         - The current interval Ik = [al, au]
        'phi'       - The interval [phi(al), phi(au)]
        'phi_prime' - The interval [phi'(al), phi'(au)]

    Instead of using the modified function:
        psi(a) = phi(a) - phi(0) - a.phi'(0),
    the function:
        psi(a) = phi(a) - a.phi'(0),
    is used, as the constant phi(0) component has no effect on the results.
    """

    # Initialise values.
    k = 0
    f_count = 0
    g_count = 0
    mod_flag = 1      # 1 - select steps with the modified function psi, 0 - with phi.
    bracketed = 0
    a0 = {'a': 0.0, 'phi': f, 'phi_prime': dot(g, p)}
    if not a_min:
        a_min = 0.0
    if not a_max:
        a_max = 4.0 * max(1.0, a_init)

    # Limits on the next trial step.  Before bracketing, allow extrapolation up to a_init + 4*a_init.
    Ik_lim = [0.0, 5.0 * a_init]

    # Interval widths of the last two iterations, used to force a bisection step when the
    # bracketing interval does not shrink fast enough.
    width = a_max - a_min
    width2 = 2.0 * width

    # Initialise sequence data by evaluating phi and phi' at the initial step.
    a = {'a': a_init}
    a['phi'] = func(*((x + a['a'] * p,) + args))
    a['phi_prime'] = dot(func_prime(*((x + a['a'] * p,) + args)), p)
    f_count = f_count + 1
    g_count = g_count + 1

    # Initialise interval data:  both end points start at alpha = 0.
    Ik = {
        'a': [0.0, 0.0],
        'phi': [a0['phi'], a0['phi']],
        'phi_prime': [a0['phi_prime'], a0['phi_prime']],
    }

    if print_flag:
        print("\n<Line search initial values>")
        print_data("Pre", -1, a0, Ik, Ik_lim)

    # Test for errors.
    if a0['phi_prime'] > 0.0:
        print("self.xk = " + repr(x))
        print("self.fk = " + repr(f))
        print("self.dfk = " + repr(g))
        print("self.pk = " + repr(p))
        print("dot(self.dfk, self.pk) = " + repr(dot(g, p)))
        print("a0['phi_prime'] = " + repr(a0['phi_prime']))
        raise NameError("The gradient at point 0 of this line search is positive, ie p is not a descent direction and the line search will not work.")
    if a['a'] < a_min:
        raise NameError("Alpha is less than alpha_min, " + repr(a['a']) + " < " + repr(a_min))
    if a['a'] > a_max:
        raise NameError("Alpha is greater than alpha_max, " + repr(a['a']) + " > " + repr(a_max))

    while True:
        if print_flag:
            print("\n<Line search iteration k = " + repr(k + 1) + " >")
            print("Bracketed: " + repr(bracketed))
            print_data("Initial", k, a, Ik, Ik_lim)

        # Test values:  the sufficient decrease line phi(0) + alpha.mu.phi'(0).
        curv = mu * a0['phi_prime']
        suff_dec = a0['phi'] + a['a'] * curv

        # Modification flag, 0 - phi, 1 - psi.  Once a step with sufficient decrease and a
        # non-negative slope is seen, switch permanently to the unmodified function phi.
        if mod_flag:
            if a['phi'] <= suff_dec and a['phi_prime'] >= 0.0:
                mod_flag = 0

        # Test for convergence using the strong Wolfe conditions.
        if print_flag:
            print("Testing for convergence using the strong Wolfe conditions.")
        if a['phi'] <= suff_dec and abs(a['phi_prime']) <= eta * abs(a0['phi_prime']):
            if print_flag:
                print("\tYes.")
                print("<Line search has converged>\n")
            return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        # Test if limits have been reached.
        if print_flag:
            print("Testing if limits have been reached.")
        if a['a'] == a_min:
            if a['phi'] > suff_dec or a['phi_prime'] >= curv:
                if print_flag:
                    print("\tYes.")
                    print("<Min alpha has been reached>\n")
                return a['a'], f_count, g_count
        if a['a'] == a_max:
            if a['phi'] <= suff_dec and a['phi_prime'] <= curv:
                if print_flag:
                    print("\tYes.")
                    print("<Max alpha has been reached>\n")
                return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        if bracketed:
            # Test for roundoff error.
            if print_flag:
                print("Testing for roundoff error.")
            if a['a'] <= Ik_lim[0] or a['a'] >= Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping due to roundoff error>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

            # Test to see if a_tol has been reached.
            if print_flag:
                print("Testing tol.")
            if Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping tol>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

        # Choose a safeguarded ak in set Ik which is a subset of [a_min, a_max], and update the interval Ik.
        a_new = {}
        if mod_flag and a['phi'] <= Ik['phi'][0] and a['phi'] > suff_dec:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the modified function psi.")

            # Calculate the modified function values and gradients at at, al, and au.
            psi = a['phi'] - curv * a['a']
            psi_l = Ik['phi'][0] - curv * Ik['a'][0]
            psi_u = Ik['phi'][1] - curv * Ik['a'][1]
            psi_prime = a['phi_prime'] - curv
            psi_l_prime = Ik['phi_prime'][0] - curv
            psi_u_prime = Ik['phi_prime'][1] - curv

            a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], psi, psi_l, psi_u, psi_prime, psi_l_prime, psi_u_prime, bracketed, Ik_lim, print_flag=print_flag)
        else:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the function phi.")
            a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], a['phi'], Ik['phi'][0], Ik['phi'][1], a['phi_prime'], Ik['phi_prime'][0], Ik['phi_prime'][1], bracketed, Ik_lim, print_flag=print_flag)

        # Bisection step:  if the interval has not shrunk to 2/3 of its width two iterations
        # ago, take its midpoint instead.  The widths are updated on every bracketed iteration.
        if bracketed:
            size = abs(Ik_new['a'][0] - Ik_new['a'][1])
            if size >= 0.66 * width2:
                if print_flag:
                    print("Bisection step.")
                a_new['a'] = 0.5 * (Ik_new['a'][0] + Ik_new['a'][1])
            width2 = width
            width = size

        # Limit the next trial step.
        if print_flag:
            print("Limiting")
            print(" Ik_lim: " + repr(Ik_lim))
        if bracketed:
            if print_flag:
                print(" Bracketed.")
            Ik_lim[0] = min(Ik_new['a'][0], Ik_new['a'][1])
            Ik_lim[1] = max(Ik_new['a'][0], Ik_new['a'][1])
        else:
            if print_flag:
                print(" Not bracketed.")
                print(" a_new['a']: " + repr(a_new['a']))
                print(" xtrapl: " + repr(1.1))
            Ik_lim[0] = a_new['a'] + 1.1 * (a_new['a'] - Ik_new['a'][0])
            Ik_lim[1] = a_new['a'] + 4.0 * (a_new['a'] - Ik_new['a'][0])
        if print_flag:
            print(" Ik_lim: " + repr(Ik_lim))

        # The step must be between a_min and a_max.
        if a_new['a'] < a_min:
            if print_flag:
                print("The step is below a_min, therefore setting the step length to a_min.")
            a_new['a'] = a_min
        if a_new['a'] > a_max:
            if print_flag:
                print("The step is above a_max, therefore setting the step length to a_max.")
            a_new['a'] = a_max

        # If no further progress is possible, fall back to the best step obtained so far:  the
        # lower end point of the UPDATED interval (MINPACK sets stp = stx after dcstep has run).
        if bracketed:
            if a_new['a'] <= Ik_lim[0] or a_new['a'] >= Ik_lim[1] or Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("aaa")
                a_new['a'] = Ik_new['a'][0]

        # Calculate new values.
        if print_flag:
            print("Calculating new values.")
        a_new['phi'] = func(*((x + a_new['a'] * p,) + args))
        a_new['phi_prime'] = dot(func_prime(*((x + a_new['a'] * p,) + args)), p)
        f_count = f_count + 1
        g_count = g_count + 1

        # Shift data from k+1 to k.
        k = k + 1
        if print_flag:
            print("Bracketed: " + repr(bracketed))
            print_data("Final", k, a_new, Ik_new, Ik_lim)
        a = deepcopy(a_new)
        Ik = deepcopy(Ik_new)
242 243 244 257 258
def update(a, Ik, at, al, au, ft, fl, fu, gt, gl, gu, bracketed, Ik_lim, d=0.66, print_flag=0):
    r"""Trial value selection and interval updating.

    Arguments
    ~~~~~~~~~

    a:  The current sequence data, a dictionary with the keys 'a', 'phi' and 'phi_prime'.  Its 'phi'
        and 'phi_prime' entries are the values stored into the updated interval.
    Ik:  The current interval data, a dictionary with the keys 'a', 'phi' and 'phi_prime', each a two
        element list [lower, upper].  Ik itself is not modified.
    at, al, au:  The trial point and the interval end points.
    ft, fl, fu:  The function values at at, al and au (either phi or the modified function psi).
    gt, gl, gu:  The gradients matching ft, fl and fu.
    bracketed:  True if the minimiser has already been bracketed.
    Ik_lim:  The [lower, upper] limits used to safeguard the trial value.
    d:  The safeguard factor used in case 3, d < 1.
    print_flag:  If true, print a trace of the decisions taken.


    Returns
    ~~~~~~~

    The tuple (at_new, Ik_new, bracketed):  the new trial value, the updated interval data (a deep copy
    of Ik with the new end points), and the updated bracketing flag.


    Trial value selection
    ~~~~~~~~~~~~~~~~~~~~~

    fl, fu, ft, gl, gu, and gt are the function and gradient values at the interval end points al and au,
    and at the trial point at.
    ac is the minimiser of the cubic that interpolates fl, ft, gl, and gt.
    aq is the minimiser of the quadratic that interpolates fl, ft, and gl.
    as (the variable a_s) is the secant step, the minimiser of the quadratic whose derivative
    interpolates gl and gt.

    The trial value selection is divided into four cases.

    Case 1: ft > fl.  In this case compute ac and aq, and set

                 / ac, if |ac - al| < |aq - al|,
         at+ = <
                 \ 1/2(aq + ac), otherwise.

    The minimiser is now bracketed.


    Case 2: ft <= fl and gt.gl < 0.  In this case compute ac and as, and set

                 / ac, if |ac - at| >= |as - at|,
         at+ = <
                 \ as, otherwise.

    The minimiser is now bracketed.


    Case 3: ft <= fl, gt.gl >= 0, and |gt| <= |gl|.  In this case at+ is chosen by extrapolating the
    function values at al and at, so the trial value at+ lies outside the interval with at and al as end
    points.  Compute ac and as.  If ac does not lie beyond at in the direction of the step (from al
    towards at), or cubic_ext reports beta2 == 0 (presumably the case where the cubic does not tend to
    infinity in that direction), ac is replaced by Ik_lim[1] if at > al and by Ik_lim[0] otherwise.

    If the minimiser is bracketed, set

                 / ac, if |ac - at| < |as - at|,
         at+ = <
                 \ as, otherwise,

    and then redefine at+ by setting

                 / min{at + d(au - at), at+}, if at > al,
         at+ = <
                 \ max{at + d(au - at), at+}, otherwise,

    for some d < 1.

    If the minimiser is not bracketed, set

                 / ac, if |ac - at| > |as - at|,
         at+ = <
                 \ as, otherwise,

    and clip at+ to the interval Ik_lim.


    Case 4: ft <= fl, gt.gl >= 0, and |gt| > |gl|.  If the minimiser is bracketed, at+ is the minimiser
    of the cubic that interpolates fu, ft, gu, and gt.  Otherwise at+ is Ik_lim[1] if at > al, and
    Ik_lim[0] otherwise.


    Interval updating
    ~~~~~~~~~~~~~~~~~

    Given a trial value at in I, the end points al+ and au+ of the updated interval I+ are determined as
    follows:
        Case U1: If f(at) > f(al), then al+ = al and au+ = at.
        Case U2: If f(at) <= f(al) and f'(at)(al - at) > 0, then al+ = at and au+ = au.
        Case U3: If f(at) <= f(al) and f'(at)(al - at) <= 0, then al+ = at and au+ = al.
    """

    # Trial value selection.

    # Case 1.
    if ft > fl:
        if print_flag:
            print("\tat selection, case 1.")
        # The minimum is bracketed.
        bracketed = 1

        # Interpolation.
        ac = cubic(al, at, fl, ft, gl, gt)
        aq = quadratic(al, at, fl, ft, gl)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\taq: " + repr(aq))

        # Return at+.
        if abs(ac - al) < abs(aq - al):
            if print_flag:
                print("\t\tabs(ac - al) < abs(aq - al), " + repr(abs(ac - al)) + " < " + repr(abs(aq - al)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - al) >= abs(aq - al), " + repr(abs(ac - al)) + " >= " + repr(abs(aq - al)))
                print("\t\tat_new = 1/2(aq + ac) = " + repr(0.5 * (aq + ac)))
            at_new = 0.5 * (aq + ac)

    # Case 2.
    elif gt * gl < 0.0:
        if print_flag:
            print("\tat selection, case 2.")
        # The minimum is bracketed.
        bracketed = 1

        # Interpolation.  'as' is a reserved keyword (Python >= 2.6), hence the name a_s.
        ac = cubic(al, at, fl, ft, gl, gt)
        a_s = secant(al, at, gl, gt)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tas: " + repr(a_s))

        # Return at+.
        if abs(ac - at) >= abs(a_s - at):
            if print_flag:
                print("\t\tabs(ac - at) >= abs(as - at), " + repr(abs(ac - at)) + " >= " + repr(abs(a_s - at)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - at) < abs(as - at), " + repr(abs(ac - at)) + " < " + repr(abs(a_s - at)))
                print("\t\tat_new = as = " + repr(a_s))
            at_new = a_s

    # Case 3.
    elif abs(gt) <= abs(gl):
        if print_flag:
            print("\tat selection, case 3.")

        # Interpolation.
        ac, beta1, beta2 = cubic_ext(al, at, fl, ft, gl, gt, full_output=1)

        # Keep ac only if it lies beyond at in the direction of the step (from al towards at).
        # Testing just 'ac > at' would be wrong whenever the step points downwards (at < al).
        if (ac - at) * (at - al) > 0.0 and beta2 != 0.0:
            # Leave ac as ac.
            if print_flag:
                print("\t\tac is beyond at and beta2 != 0.0")
        elif at > al:
            # Set ac to the upper limit.
            if print_flag:
                print("\t\tat > al, " + repr(at) + " > " + repr(al))
            ac = Ik_lim[1]
        else:
            # Set ac to the lower limit.
            ac = Ik_lim[0]

        a_s = secant(al, at, gl, gt)

        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tas: " + repr(a_s))

        # Test if bracketed.
        if bracketed:
            if print_flag:
                print("\t\tBracketed")
            if abs(ac - at) < abs(a_s - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) < abs(as - at), " + repr(abs(ac - at)) + " < " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) >= abs(as - at), " + repr(abs(ac - at)) + " >= " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = as = " + repr(a_s))
                at_new = a_s

            # Redefine at+ so that it does not move more than a fraction d of the way towards au.
            if print_flag:
                print("\t\tRedefining at+")
            if at > al:
                at_new = min(at + d * (au - at), at_new)
                if print_flag:
                    print("\t\t\tat > al, " + repr(at) + " > " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
            else:
                at_new = max(at + d * (au - at), at_new)
                if print_flag:
                    print("\t\t\tat <= al, " + repr(at) + " <= " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
        else:
            if print_flag:
                print("\t\tNot bracketed")
            if abs(ac - at) > abs(a_s - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) > abs(as - at), " + repr(abs(ac - at)) + " > " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) <= abs(as - at), " + repr(abs(ac - at)) + " <= " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = as = " + repr(a_s))
                at_new = a_s

            # Check limits (only needed while extrapolating, as in MINPACK dcstep).
            if print_flag:
                print("\t\tChecking limits.")
            if at_new < Ik_lim[0]:
                if print_flag:
                    print("\t\t\tat_new < Ik_lim[0], " + repr(at_new) + " < " + repr(Ik_lim[0]))
                    print("\t\t\tat_new = " + repr(Ik_lim[0]))
                at_new = Ik_lim[0]
            if at_new > Ik_lim[1]:
                if print_flag:
                    print("\t\t\tat_new > Ik_lim[1], " + repr(at_new) + " > " + repr(Ik_lim[1]))
                    print("\t\t\tat_new = " + repr(Ik_lim[1]))
                at_new = Ik_lim[1]

    # Case 4.
    else:
        if print_flag:
            print("\tat selection, case 4.")
        if bracketed:
            if print_flag:
                print("\t\tbracketed.")
            at_new = cubic(au, at, fu, ft, gu, gt)
            if print_flag:
                print("\t\tat_new = " + repr(at_new))
        elif at > al:
            if print_flag:
                print("\t\tnot bracketed but at > al, " + repr(at) + " > " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[1]))
            at_new = Ik_lim[1]
        else:
            if print_flag:
                print("\t\tnot bracketed but at <= al, " + repr(at) + " <= " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[0]))
            at_new = Ik_lim[0]

    # Interval updating algorithm.  The interval stores the unmodified phi values of the trial point.
    Ik_new = deepcopy(Ik)

    if ft > fl:
        # Case U1:  at becomes the upper end point.
        if print_flag:
            print("\tIk update, case a, ft > fl.")
        Ik_new['a'][1] = at
        Ik_new['phi'][1] = a['phi']
        Ik_new['phi_prime'][1] = a['phi_prime']
    elif gt * (al - at) > 0.0:
        # Case U2:  at replaces the lower end point.
        if print_flag:
            print("\tIk update, case b, gt*(al - at) > 0.0.")
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']
    else:
        # Case U3:  at becomes the lower end point and the old lower end point the upper one.
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']
        Ik_new['a'][1] = al
        Ik_new['phi'][1] = Ik['phi'][0]
        Ik_new['phi_prime'][1] = Ik['phi_prime'][0]

        if print_flag:
            print("\tIk update, case c.")
            print("\t\tat: " + repr(at))
            print("\t\ta['phi']: " + repr(a['phi']))
            print("\t\ta['phi_prime']: " + repr(a['phi_prime']))
            print("\t\tIk_new['a']: " + repr(Ik_new['a']))
            print("\t\tIk_new['phi']: " + repr(Ik_new['phi']))
            print("\t\tIk_new['phi_prime']: " + repr(Ik_new['phi_prime']))

    return at_new, Ik_new, bracketed
514