Package minfx :: Package line_search :: Module more_thuente
[hide private]
[frames] | no frames]

Source Code for Module minfx.line_search.more_thuente

  1  ############################################################################### 
  2  #                                                                             # 
  3  # Copyright (C) 2003-2013 Edward d'Auvergne                                   # 
  4  #                                                                             # 
  5  # This file is part of the minfx optimisation library,                        # 
  6  # https://sourceforge.net/projects/minfx                                      # 
  7  #                                                                             # 
  8  # This program is free software: you can redistribute it and/or modify        # 
  9  # it under the terms of the GNU General Public License as published by        # 
 10  # the Free Software Foundation, either version 3 of the License, or           # 
 11  # (at your option) any later version.                                         # 
 12  #                                                                             # 
 13  # This program is distributed in the hope that it will be useful,             # 
 14  # but WITHOUT ANY WARRANTY; without even the implied warranty of              # 
 15  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               # 
 16  # GNU General Public License for more details.                                # 
 17  #                                                                             # 
 18  # You should have received a copy of the GNU General Public License           # 
 19  # along with this program.  If not, see <http://www.gnu.org/licenses/>.       # 
 20  #                                                                             # 
 21  ############################################################################### 
 22   
 23  # Module docstring. 
 24  """A line search algorithm from More and Thuente. 
 25   
 26  This file is part of the U{minfx optimisation library<https://sourceforge.net/projects/minfx>}. 
 27  """ 
 28   
 29  # Python module imports. 
 30  from copy import deepcopy 
 31  from math import sqrt 
 32  from numpy import dot 
 33  import sys 
 34   
 35  # Minfx module imports. 
 36  from minfx.line_search.interpolate import cubic_int, cubic_ext, quadratic_fafbga, quadratic_gagb 
 37   
 38  # Rename the interpolation functions. 
 39  cubic = cubic_int 
 40  quadratic = quadratic_fafbga 
 41  secant = quadratic_gagb 
 42   
 43   
def more_thuente(func, func_prime, args, x, f, g, p, a_init=1.0, a_min=1e-25, a_max=None, a_tol=1e-10, phi_min=-1e3, mu=0.001, eta=0.1, print_flag=0):
    """A line search algorithm from More and Thuente.

    More, J. J., and Thuente, D. J. 1994, Line search algorithms with guaranteed sufficient decrease.  ACM Trans. Math. Softw. 20, 286-307.

    The step length alpha along the direction p is iterated on, using phi(alpha) = func(x + alpha*p) and phi'(alpha) = dot(func_prime(x + alpha*p), p), until the strong Wolfe conditions hold or one of the safeguard exits (alpha limits, roundoff, interval tolerance) is hit.


    Internal variables
    ==================

    a0, the null sequence data structure containing the following keys:

        - 'a' - 0
        - 'phi' - phi(0)
        - 'phi_prime' - phi'(0)

    a, the sequence data structure containing the following keys:

        - 'a' - alpha
        - 'phi' - phi(alpha)
        - 'phi_prime' - phi'(alpha)

    Ik, the interval data structure containing the following keys:

        - 'a' - The current interval Ik = [al, au]
        - 'phi' - The interval [phi(al), phi(au)]
        - 'phi_prime' - The interval [phi'(al), phi'(au)]

    Instead of using the modified function::

        psi(a) = phi(a) - phi(0) - mu.a.phi'(0),

    the function::

        psi(a) = phi(a) - mu.a.phi'(0),

    was used as the phi(0) component has no effect on the results.


    @param func:            The function to minimise, called as func(x + alpha*p, *args).
    @param func_prime:      The gradient function, called as func_prime(x + alpha*p, *args).
    @param args:            The extra arguments appended to every func and func_prime call.
    @param x:               The starting point of the line search.
    @param f:               The function value at x, used as phi(0).
    @param g:               The gradient at x, dot(g, p) is used as phi'(0).
    @param p:               The search direction.
    @keyword a_init:        The initial step length.
    @keyword a_min:         The minimum step length (a false value is replaced by 0.0).
    @keyword a_max:         The maximum step length (a false value is replaced by 4*max(1, a_init)).
    @keyword a_tol:         The relative width of the bracketing interval below which the search stops.
    @keyword phi_min:       Accepted for interface compatibility, this value is not used in this function.
    @keyword mu:            The sufficient decrease constant.
    @keyword eta:           The curvature condition constant.
    @keyword print_flag:    If true, the progress of the search is printed out.
    @return:                The step length, the number of function calls and the number of gradient calls.
    @rtype:                 tuple
    @raise NameError:       If p is not a descent direction, or if a_init lies outside of [a_min, a_max].
    """

    # Initialise values.
    k = 0
    f_count = 0
    g_count = 0
    mod_flag = 1
    bracketed = 0
    a0 = {}
    a0['a'] = 0.0
    a0['phi'] = f
    a0['phi_prime'] = dot(g, p)
    if not a_min:
        a_min = 0.0
    if not a_max:
        a_max = 4.0*max(1.0, a_init)
    Ik_lim = [0.0, 5.0*a_init]
    width = a_max - a_min
    width2 = 2.0*width

    # The sufficient decrease slope mu.phi'(0) is constant for the whole search.
    curv = mu * a0['phi_prime']

    # Initialise sequence data.
    a = {}
    a['a'] = a_init
    a['phi'] = func(x + a['a']*p, *args)
    a['phi_prime'] = dot(func_prime(x + a['a']*p, *args), p)
    f_count = f_count + 1
    g_count = g_count + 1

    # Initialise interval data.
    Ik = {}
    Ik['a'] = [0.0, 0.0]
    Ik['phi'] = [a0['phi'], a0['phi']]
    Ik['phi_prime'] = [a0['phi_prime'], a0['phi_prime']]

    if print_flag:
        print("\n<Line search initial values>")
        print_data("Pre", -1, a0, Ik, Ik_lim)

    # Test for errors.
    if a0['phi_prime'] > 0.0:
        if print_flag:
            print("xk = " + repr(x))
            print("fk = " + repr(f))
            print("dfk = " + repr(g))
            print("pk = " + repr(p))
            print("dot(dfk, pk) = " + repr(dot(g, p)))
            print("a0['phi_prime'] = " + repr(a0['phi_prime']))
        raise NameError("The gradient at point 0 of this line search is positive, ie p is not a descent direction and the line search will not work.")
    if a['a'] < a_min:
        raise NameError("Alpha is less than alpha_min, " + repr(a['a']) + " < " + repr(a_min))
    if a['a'] > a_max:
        raise NameError("Alpha is greater than alpha_max, " + repr(a['a']) + " > " + repr(a_max))

    while True:
        if print_flag:
            print("\n<Line search iteration k = " + repr(k+1) + " >")
            print("Bracketed: " + repr(bracketed))
            print_data("Initial", k, a, Ik, Ik_lim)

        # Test value for the sufficient decrease condition.
        suff_dec = a0['phi'] + a['a'] * curv

        # Modification flag, 0 - phi, 1 - psi.
        if mod_flag:
            if a['phi'] <= suff_dec and a['phi_prime'] >= 0.0:
                mod_flag = 0

        # Test for convergence using the strong Wolfe conditions.
        if print_flag:
            print("Testing for convergence using the strong Wolfe conditions.")
        if a['phi'] <= suff_dec and abs(a['phi_prime']) <= eta * abs(a0['phi_prime']):
            if print_flag:
                print("\tYes.")
                print("<Line search has converged>\n")
            return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        # Test if limits have been reached.
        if print_flag:
            print("Testing if limits have been reached.")
        if a['a'] == a_min:
            if a['phi'] > suff_dec or a['phi_prime'] >= curv:
                if print_flag:
                    print("\tYes.")
                    print("<Min alpha has been reached>\n")
                return a['a'], f_count, g_count
        if a['a'] == a_max:
            if a['phi'] <= suff_dec and a['phi_prime'] <= curv:
                if print_flag:
                    print("\tYes.")
                    print("<Max alpha has been reached>\n")
                return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        if bracketed:
            # Test for roundoff error.
            if print_flag:
                print("Testing for roundoff error.")
            if a['a'] <= Ik_lim[0] or a['a'] >= Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping due to roundoff error>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

            # Test to see if a_tol has been reached.
            if print_flag:
                print("Testing tol.")
            if Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping tol>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

        # Choose a safeguarded ak in set Ik which is a subset of [a_min, a_max], and update the interval Ik.
        a_new = {}
        if mod_flag and a['phi'] <= Ik['phi'][0] and a['phi'] > suff_dec:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the modified function psi.")

            # Calculate the modified function values and gradients at at, al, and au.
            psi = a['phi'] - curv * a['a']
            psi_l = Ik['phi'][0] - curv * Ik['a'][0]
            psi_u = Ik['phi'][1] - curv * Ik['a'][1]
            psi_prime = a['phi_prime'] - curv
            psi_l_prime = Ik['phi_prime'][0] - curv
            psi_u_prime = Ik['phi_prime'][1] - curv

            a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], psi, psi_l, psi_u, psi_prime, psi_l_prime, psi_u_prime, bracketed, Ik_lim, print_flag=print_flag)
        else:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the function phi.")
            a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], a['phi'], Ik['phi'][0], Ik['phi'][1], a['phi_prime'], Ik['phi_prime'][0], Ik['phi_prime'][1], bracketed, Ik_lim, print_flag=print_flag)

        # Bisection step (forced when the interval has not shrunk enough over the last two iterations).
        if bracketed:
            size = abs(Ik_new['a'][0] - Ik_new['a'][1])
            if size >= 0.66 * width2:
                if print_flag:
                    print("Bisection step.")
                a_new['a'] = 0.5 * (Ik_new['a'][0] + Ik_new['a'][1])
            width2 = width
            width = size

        # Limit.
        if print_flag:
            print("Limiting")
            print(" Ik_lim: " + repr(Ik_lim))
        if bracketed:
            if print_flag:
                print(" Bracketed.")
            Ik_lim[0] = min(Ik_new['a'][0], Ik_new['a'][1])
            Ik_lim[1] = max(Ik_new['a'][0], Ik_new['a'][1])
        else:
            if print_flag:
                print(" Not bracketed.")
                print(" a_new['a']: " + repr(a_new['a']))
                print(" xtrapl: " + repr(1.1))
            Ik_lim[0] = a_new['a'] + 1.1 * (a_new['a'] - Ik_new['a'][0])
            Ik_lim[1] = a_new['a'] + 4.0 * (a_new['a'] - Ik_new['a'][0])
        if print_flag:
            print(" Ik_lim: " + repr(Ik_lim))

        # If no further progress is possible, fall back to the best point found so far.
        # This is the lower end of the *updated* interval (stx in MINPACK-2 dcsrch), not of the old one.
        if bracketed:
            if a_new['a'] <= Ik_lim[0] or a_new['a'] >= Ik_lim[1] or Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("aaa")
                a_new['a'] = Ik_new['a'][0]

        # The step must be between a_min and a_max.
        if a_new['a'] < a_min:
            if print_flag:
                print("The step is below a_min, therefore setting the step length to a_min.")
            a_new['a'] = a_min
        if a_new['a'] > a_max:
            if print_flag:
                print("The step is above a_max, therefore setting the step length to a_max.")
            a_new['a'] = a_max

        # Calculate new values.
        if print_flag:
            print("Calculating new values.")
        a_new['phi'] = func(x + a_new['a']*p, *args)
        a_new['phi_prime'] = dot(func_prime(x + a_new['a']*p, *args), p)
        f_count = f_count + 1
        g_count = g_count + 1

        # Shift data from k+1 to k.
        k = k + 1
        if print_flag:
            print("Bracketed: " + repr(bracketed))
            print_data("Final", k, a_new, Ik_new, Ik_lim)
        a = deepcopy(a_new)
        Ik = deepcopy(Ik_new)
280 281 294 295
def update(a, Ik, at, al, au, ft, fl, fu, gt, gl, gu, bracketed, Ik_lim, d=0.66, print_flag=0):
    """Trial value selection and interval updating.

    Trial value selection
    =====================

    fl, fu, ft, gl, gu, and gt are the function and gradient values at the interval end points al and au, and at the trial point at.  The new trial value at+ is selected from one of four cases:

        - Case 1, ft > fl:  The minimum is flagged as bracketed.  The cubic minimiser ac (through fl, ft, gl, gt) and the quadratic minimiser aq (through fl, ft, gl) are computed.  at+ is ac if |ac - al| < |aq - al|, otherwise it is 1/2(aq + ac).
        - Case 2, ft <= fl and gt.gl < 0:  The minimum is flagged as bracketed.  The cubic minimiser ac and the secant minimiser as (through gl, gt) are computed.  at+ is ac if |ac - at| >= |as - at|, otherwise it is as.
        - Case 3, ft <= fl, gt.gl >= 0, and |gt| <= |gl|:  The extrapolating cubic is computed.  Unless ac > at with a non-zero beta2, ac is replaced by the upper (at > al) or lower limit of Ik_lim.  If bracketed, the point of ac and as closest to at is taken and then restricted using at + d(au - at).  If not bracketed, the point farthest from at is taken and then clamped to Ik_lim.
        - Case 4, ft <= fl, gt.gl >= 0, and |gt| > |gl|:  If bracketed, at+ is the minimiser of the cubic through fu, ft, gu, and gt.  Otherwise at+ is the upper limit of Ik_lim if at > al, or the lower limit if not.


    Interval updating
    =================

    The end points al+ and au+ of the updated interval are determined as follows:

        - Case a:  If ft > fl, then al+ = al and au+ = at.
        - Case b:  If ft <= fl and gt(al - at) > 0, then al+ = at and au+ = au.
        - Case c:  Otherwise, al+ = at and au+ = al.


    @param a:               The sequence data structure for the trial point (keys 'phi' and 'phi_prime' are read).
    @param Ik:              The interval data structure (it is copied, not modified).
    @param at:              The trial point.
    @param al:              The interval end point al.
    @param au:              The interval end point au.
    @param ft:              The function value at at.
    @param fl:              The function value at al.
    @param fu:              The function value at au.
    @param gt:              The gradient at at.
    @param gl:              The gradient at al.
    @param gu:              The gradient at au.
    @param bracketed:       The flag specifying if the minimum has been bracketed.
    @param Ik_lim:          The lower and upper limits for the new trial value.
    @keyword d:             The factor used to restrict at+ in case 3 when bracketed.
    @keyword print_flag:    If true, the selection and update steps are printed out.
    @return:                The new trial value at+, the updated interval data structure, and the bracketed flag.
    @rtype:                 tuple
    """

    # The limits are only read in this function.
    lim_lower = Ik_lim[0]
    lim_upper = Ik_lim[1]


    # Trial value selection.
    ########################

    # Case 1.
    if ft > fl:
        if print_flag:
            print("\tat selection, case 1.")

        # The minimum is bracketed.
        bracketed = 1

        # Interpolation.
        a_cubic = cubic(al, at, fl, ft, gl, gt)
        a_quad = quadratic(al, at, fl, ft, gl)
        if print_flag:
            print("\t\tac: " + repr(a_cubic))
            print("\t\taq: " + repr(a_quad))

        # Distances from al.
        dist_cubic = abs(a_cubic - al)
        dist_quad = abs(a_quad - al)

        # The new trial value.
        if dist_cubic < dist_quad:
            trial = a_cubic
            if print_flag:
                print("\t\tabs(ac - al) < abs(aq - al), " + repr(dist_cubic) + " < " + repr(dist_quad))
                print("\t\tat_new = ac = " + repr(trial))
        else:
            trial = 0.5*(a_quad + a_cubic)
            if print_flag:
                print("\t\tabs(ac - al) >= abs(aq - al), " + repr(dist_cubic) + " >= " + repr(dist_quad))
                print("\t\tat_new = 1/2(aq + ac) = " + repr(trial))

    # Case 2.
    elif gt * gl < 0.0:
        if print_flag:
            print("\tat selection, case 2.")

        # The minimum is bracketed.
        bracketed = 1

        # Interpolation.
        a_cubic = cubic(al, at, fl, ft, gl, gt)
        a_sec = secant(al, at, gl, gt)
        if print_flag:
            print("\t\tac: " + repr(a_cubic))
            print("\t\tasec: " + repr(a_sec))

        # Distances from at.
        dist_cubic = abs(a_cubic - at)
        dist_sec = abs(a_sec - at)

        # The new trial value.
        if dist_cubic >= dist_sec:
            trial = a_cubic
            if print_flag:
                print("\t\tabs(ac - at) >= abs(asec - at), " + repr(dist_cubic) + " >= " + repr(dist_sec))
                print("\t\tat_new = ac = " + repr(trial))
        else:
            trial = a_sec
            if print_flag:
                print("\t\tabs(ac - at) < abs(asec - at), " + repr(dist_cubic) + " < " + repr(dist_sec))
                print("\t\tat_new = asec = " + repr(trial))

    # Case 3.
    elif abs(gt) <= abs(gl):
        if print_flag:
            print("\tat selection, case 3.")

        # Interpolation (beta1 is returned but not needed here).
        a_cubic, beta1, beta2 = cubic_ext(al, at, fl, ft, gl, gt, full_output=1)

        if a_cubic > at and beta2 != 0.0:
            # The cubic minimiser is kept.
            if print_flag:
                print("\t\tac > at and beta2 != 0.0")
        elif at > al:
            # The cubic minimiser is replaced by the upper limit.
            if print_flag:
                print("\t\tat > al, " + repr(at) + " > " + repr(al))
            a_cubic = lim_upper
        else:
            # The cubic minimiser is replaced by the lower limit.
            a_cubic = lim_lower

        a_sec = secant(al, at, gl, gt)

        if print_flag:
            print("\t\tac: " + repr(a_cubic))
            print("\t\tasec: " + repr(a_sec))

        # Distances from at.
        dist_cubic = abs(a_cubic - at)
        dist_sec = abs(a_sec - at)

        if bracketed:
            if print_flag:
                print("\t\tBracketed")

            # Take the point closest to at.
            if dist_cubic < dist_sec:
                trial = a_cubic
                if print_flag:
                    print("\t\t\tabs(ac - at) < abs(asec - at), " + repr(dist_cubic) + " < " + repr(dist_sec))
                    print("\t\t\tat_new = ac = " + repr(trial))
            else:
                trial = a_sec
                if print_flag:
                    print("\t\t\tabs(ac - at) >= abs(asec - at), " + repr(dist_cubic) + " >= " + repr(dist_sec))
                    print("\t\t\tat_new = asec = " + repr(trial))

            # Redefine at+.
            if print_flag:
                print("\t\tRedefining at+")
            if at > al:
                trial = min(at + d*(au - at), trial)
                if print_flag:
                    print("\t\t\tat > al, " + repr(at) + " > " + repr(al))
                    print("\t\t\tat_new = " + repr(trial))
            else:
                trial = max(at + d*(au - at), trial)
                if print_flag:
                    print("\t\t\tat <= al, " + repr(at) + " <= " + repr(al))
                    print("\t\t\tat_new = " + repr(trial))

        else:
            if print_flag:
                print("\t\tNot bracketed")

            # Take the point farthest from at.
            if dist_cubic > dist_sec:
                trial = a_cubic
                if print_flag:
                    print("\t\t\tabs(ac - at) > abs(asec - at), " + repr(dist_cubic) + " > " + repr(dist_sec))
                    print("\t\t\tat_new = ac = " + repr(trial))
            else:
                trial = a_sec
                if print_flag:
                    print("\t\t\tabs(ac - at) <= abs(asec - at), " + repr(dist_cubic) + " <= " + repr(dist_sec))
                    print("\t\t\tat_new = asec = " + repr(trial))

            # Check limits.
            # NOTE(review): placed in the not-bracketed branch based on the debug message depth - confirm against the upstream file.
            if print_flag:
                print("\t\tChecking limits.")
            if trial < lim_lower:
                if print_flag:
                    print("\t\t\tat_new < Ik_lim[0], " + repr(trial) + " < " + repr(lim_lower))
                    print("\t\t\tat_new = " + repr(lim_lower))
                trial = lim_lower
            if trial > lim_upper:
                if print_flag:
                    print("\t\t\tat_new > Ik_lim[1], " + repr(trial) + " > " + repr(lim_upper))
                    print("\t\t\tat_new = " + repr(lim_upper))
                trial = lim_upper

    # Case 4.
    else:
        if print_flag:
            print("\tat selection, case 4.")
        if bracketed:
            if print_flag:
                print("\t\tbracketed.")
            trial = cubic(au, at, fu, ft, gu, gt)
            if print_flag:
                print("\t\tat_new = " + repr(trial))
        elif at > al:
            trial = lim_upper
            if print_flag:
                print("\t\tnot bracketed but at > al, " + repr(at) + " > " + repr(al))
                print("\t\tat_new = " + repr(trial))
        else:
            trial = lim_lower
            if print_flag:
                print("\t\tnot bracketed but at <= al, " + repr(at) + " <= " + repr(al))
                print("\t\tat_new = " + repr(trial))


    # Interval updating algorithm.
    ##############################

    # Work on a copy so that the input interval is left untouched.
    interval = deepcopy(Ik)

    def _store_trial(end):
        """Place the trial point data at the given end (0 or 1) of the new interval."""
        interval['a'][end] = at
        interval['phi'][end] = a['phi']
        interval['phi_prime'][end] = a['phi_prime']

    if ft > fl:
        if print_flag:
            print("\tIk update, case a, ft > fl.")
        _store_trial(1)
    elif gt*(al - at) > 0.0:
        if print_flag:
            print("\tIk update, case b, gt*(al - at) > 0.0.")
        _store_trial(0)
    else:
        # The old al becomes the new au, and the trial point becomes the new al.
        interval['a'][1] = al
        interval['phi'][1] = Ik['phi'][0]
        interval['phi_prime'][1] = Ik['phi_prime'][0]
        _store_trial(0)

        # NOTE(review): this print out is assumed to belong to case c only - confirm against the upstream file.
        if print_flag:
            print("\tIk update, case c.")
            print("\t\tat: " + repr(at))
            print("\t\ta['phi']: " + repr(a['phi']))
            print("\t\ta['phi_prime']: " + repr(a['phi_prime']))
            print("\t\tIk_new['a']: " + repr(interval['a']))
            print("\t\tIk_new['phi']: " + repr(interval['phi']))
            print("\t\tIk_new['phi_prime']: " + repr(interval['phi_prime']))

    return trial, interval, bracketed
553