1   
  2   
  3   
  4   
  5   
  6   
  7   
  8   
  9   
 10   
 11   
 12   
 13   
 14   
 15   
 16   
 17   
 18   
 19   
 20   
 21   
 22   
 23   
 24  """A line search algorithm from More and Thuente. 
 25   
 26  This file is part of the U{minfx optimisation library<https://gna.org/projects/minfx>}. 
 27  """ 
 28   
 29   
 30  from copy import deepcopy 
 31  from math import sqrt 
 32  from numpy import dot 
 33  import sys 
 34   
 35   
 36  from minfx.line_search.interpolate import cubic_int, cubic_ext, quadratic_fafbga, quadratic_gagb 
 37   
 38   
 39  cubic = cubic_int 
 40  quadratic = quadratic_fafbga 
 41  secant = quadratic_gagb 
 42   
 43   
 44 -def more_thuente(func, func_prime, args, x, f, g, p, a_init=1.0, a_min=1e-25, a_max=None, a_tol=1e-10, phi_min=-1e3, mu=0.001, eta=0.1, print_flag=0): 
  45      """A line search algorithm from More and Thuente. 
 46   
 47      More, J. J., and Thuente, D. J. 1994, Line search algorithms with guaranteed sufficient decrease. ACM Trans. Math. Softw. 20, 286-307. 
 48   
 49   
 50      Internal variables 
 51      ================== 
 52   
 53      a0, the null sequence data structure containing the following keys: 
 54   
 55          - 'a'        - 0 
 56          - 'phi'        - phi(0) 
 57          - 'phi_prime'    - phi'(0) 
 58   
 59      a, the sequence data structure containing the following keys: 
 60   
 61          - 'a'        - alpha 
 62          - 'phi'        - phi(alpha) 
 63          - 'phi_prime'    - phi'(alpha) 
 64   
 65      Ik, the interval data structure containing the following keys: 
 66   
 67          - 'a'        - The current interval Ik = [al, au] 
 68          - 'phi'        - The interval [phi(al), phi(au)] 
 69          - 'phi_prime'    - The interval [phi'(al), phi'(au)] 
 70   
 71      Instead of using the modified function:: 
 72   
 73          psi(a) = phi(a) - phi(0) - a.phi'(0), 
 74   
 75      the function:: 
 76   
 77          psi(a) = phi(a) - a.phi'(0), 
 78   
 79      was used as the phi(0) component has no effect on the results. 
 80      """ 
 81   
 82       
 83      k = 0 
 84      f_count = 0 
 85      g_count = 0 
 86      mod_flag = 1 
 87      bracketed = 0 
 88      a0 = {} 
 89      a0['a'] = 0.0 
 90      a0['phi'] = f 
 91      a0['phi_prime'] = dot(g, p) 
 92      if not a_min: 
 93          a_min = 0.0 
 94      if not a_max: 
 95          a_max = 4.0*max(1.0, a_init) 
 96      Ik_lim = [0.0, 5.0*a_init] 
 97      width = a_max - a_min 
 98      width2 = 2.0*width 
 99   
100       
101      a = {} 
102      a['a'] = a_init 
103      a['phi'] = func(*(x + a['a']*p,)+args) 
104      a['phi_prime'] = dot(func_prime(*(x + a['a']*p,)+args), p) 
105      f_count = f_count + 1 
106      g_count = g_count + 1 
107   
108       
109      Ik = {} 
110      Ik['a'] = [0.0, 0.0] 
111      Ik['phi'] = [a0['phi'], a0['phi']] 
112      Ik['phi_prime'] = [a0['phi_prime'], a0['phi_prime']] 
113   
114      if print_flag: 
115          print("\n<Line search initial values>") 
116          print_data("Pre", -1, a0, Ik, Ik_lim) 
117   
118       
119      if a0['phi_prime'] > 0.0: 
120          if print_flag: 
121              print("xk = " + repr(x)) 
122              print("fk = " + repr(f)) 
123              print("dfk = " + repr(g)) 
124              print("pk = " + repr(p)) 
125              print("dot(dfk, pk) = " + repr(dot(g, p))) 
126              print("a0['phi_prime'] = " + repr(a0['phi_prime'])) 
127          raise NameError("The gradient at point 0 of this line search is positive, ie p is not a descent direction and the line search will not work.") 
128      if a['a'] < a_min: 
129          raise NameError("Alpha is less than alpha_min, " + repr(a['a']) + " > " + repr(a_min)) 
130      if a['a'] > a_max: 
131          raise NameError("Alpha is greater than alpha_max, " + repr(a['a']) + " > " + repr(a_max)) 
132   
133      while True: 
134          if print_flag: 
135              print("\n<Line search iteration k = " + repr(k+1) + " >") 
136              print("Bracketed: " + repr(bracketed)) 
137              print_data("Initial", k, a, Ik, Ik_lim) 
138   
139           
140          curv = mu * a0['phi_prime'] 
141          suff_dec = a0['phi'] + a['a'] * curv 
142   
143           
144          if mod_flag: 
145              if a['phi'] <= suff_dec and a['phi_prime'] >= 0.0: 
146                  mod_flag = 0 
147   
148           
149          if print_flag: 
150              print("Testing for convergence using the strong Wolfe conditions.") 
151          if a['phi'] <= suff_dec and abs(a['phi_prime']) <= eta * abs(a0['phi_prime']): 
152              if print_flag: 
153                  print("\tYes.") 
154                  print("<Line search has converged>\n") 
155              return a['a'], f_count, g_count 
156          if print_flag: 
157              print("\tNo.") 
158   
159           
160          if print_flag: 
161              print("Testing if limits have been reached.") 
162          if a['a'] == a_min: 
163              if a['phi'] > suff_dec or a['phi_prime'] >= curv: 
164                  if print_flag: 
165                      print("\tYes.") 
166                      print("<Min alpha has been reached>\n") 
167                  return a['a'], f_count, g_count 
168          if a['a'] == a_max: 
169              if a['phi'] <= suff_dec and a['phi_prime'] <= curv: 
170                  if print_flag: 
171                      print("\tYes.") 
172                      print("<Max alpha has been reached>\n") 
173                  return a['a'], f_count, g_count 
174          if print_flag: 
175              print("\tNo.") 
176   
177          if bracketed: 
178               
179              if print_flag: 
180                  print("Testing for roundoff error.") 
181              if a['a'] <= Ik_lim[0] or a['a'] >= Ik_lim[1]: 
182                  if print_flag: 
183                      print("\tYes.") 
184                      print("<Stopping due to roundoff error>\n") 
185                  return a['a'], f_count, g_count 
186              if print_flag: 
187                  print("\tNo.") 
188   
189               
190              if print_flag: 
191                  print("Testing tol.") 
192              if Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]: 
193                  if print_flag: 
194                      print("\tYes.") 
195                      print("<Stopping tol>\n") 
196                  return a['a'], f_count, g_count 
197              if print_flag: 
198                  print("\tNo.") 
199   
200           
201          a_new = {} 
202          if mod_flag and a['phi'] <= Ik['phi'][0] and a['phi'] > suff_dec: 
203              if print_flag: 
204                  print("Choosing ak and updating the interval Ik using the modified function psi.") 
205   
206               
207              psi = a['phi'] - curv * a['a'] 
208              psi_l = Ik['phi'][0] - curv * Ik['a'][0] 
209              psi_u = Ik['phi'][1] - curv * Ik['a'][1] 
210              psi_prime = a['phi_prime'] - curv 
211              psi_l_prime = Ik['phi_prime'][0] - curv 
212              psi_u_prime = Ik['phi_prime'][1] - curv 
213   
214              a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], psi, psi_l, psi_u, psi_prime, psi_l_prime, psi_u_prime, bracketed, Ik_lim, print_flag=print_flag) 
215          else: 
216              if print_flag: 
217                  print("Choosing ak and updating the interval Ik using the function phi.") 
218              a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], a['phi'], Ik['phi'][0], Ik['phi'][1], a['phi_prime'], Ik['phi_prime'][0], Ik['phi_prime'][1], bracketed, Ik_lim, print_flag=print_flag) 
219   
220           
221          if bracketed: 
222              size = abs(Ik_new['a'][0] - Ik_new['a'][1]) 
223              if size >= 0.66 * width2: 
224                  if print_flag: 
225                      print("Bisection step.") 
226                  a_new['a'] = 0.5 * (Ik_new['a'][0] + Ik_new['a'][1]) 
227              width2 = width 
228              width = size 
229   
230           
231          if print_flag: 
232              print("Limiting") 
233              print("   Ik_lim: " + repr(Ik_lim)) 
234          if bracketed: 
235              if print_flag: 
236                  print("   Bracketed.") 
237              Ik_lim[0] = min(Ik_new['a'][0], Ik_new['a'][1]) 
238              Ik_lim[1] = max(Ik_new['a'][0], Ik_new['a'][1]) 
239          else: 
240              if print_flag: 
241                  print("   Not bracketed.") 
242                  print("   a_new['a']: " + repr(a_new['a'])) 
243                  print("   xtrapl:     " + repr(1.1)) 
244              Ik_lim[0] = a_new['a'] + 1.1 * (a_new['a'] - Ik_new['a'][0]) 
245              Ik_lim[1] = a_new['a'] + 4.0 * (a_new['a'] - Ik_new['a'][0]) 
246          if print_flag: 
247              print("   Ik_lim: " + repr(Ik_lim)) 
248   
249          if bracketed: 
250              if a_new['a'] <= Ik_lim[0] or a_new['a'] >= Ik_lim[1] or Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]: 
251                  if print_flag: 
252                      print("aaa") 
253                  a_new['a'] = Ik['a'][0] 
254   
255           
256          if a_new['a'] < a_min: 
257              if print_flag: 
258                  print("The step is below a_min, therefore setting the step length to a_min.") 
259              a_new['a'] = a_min 
260          if a_new['a'] > a_max: 
261              if print_flag: 
262                  print("The step is above a_max, therefore setting the step length to a_max.") 
263              a_new['a'] = a_max 
264   
265           
266          if print_flag: 
267              print("Calculating new values.") 
268          a_new['phi'] = func(*(x + a_new['a']*p,)+args) 
269          a_new['phi_prime'] = dot(func_prime(*(x + a_new['a']*p,)+args), p) 
270          f_count = f_count + 1 
271          g_count = g_count + 1 
272   
273           
274          k = k + 1 
275          if print_flag: 
276              print("Bracketed: " + repr(bracketed)) 
277              print_data("Final", k, a_new, Ik_new, Ik_lim) 
278          a = deepcopy(a_new) 
279          Ik = deepcopy(Ik_new) 
 280   
281   
def print_data(text, k, a, Ik, Ik_lim):
    """Temp func for debugging.

    Print the current state of the line search to standard output.

    @param text:    The label placed at the start of the printout header.
    @param k:       The iteration index, printed as k+1.
    @param a:       The step data structure with the keys 'a', 'phi', and 'phi_prime'.
    @param Ik:      The interval data structure with the keys 'a', 'phi', and 'phi_prime'.
    @param Ik_lim:  The current step limits.
    """

    print(text + " data printout:")
    print("   Iteration:   " + repr(k+1))
    print("   a:           " + repr(a['a']))
    print("   phi:         " + repr(a['phi']))
    print("   phi_prime:   " + repr(a['phi_prime']))
    print("   Ik:          " + repr(Ik['a']))
    print("   phi_I:       " + repr(Ik['phi']))
    print("   phi_I_prime: " + repr(Ik['phi_prime']))
    print("   Ik_lim:      " + repr(Ik_lim))
 294   
295   
def update(a, Ik, at, al, au, ft, fl, fu, gt, gl, gu, bracketed, Ik_lim, d=0.66, print_flag=0):
    r"""Trial value selection and interval updating.

    Trial value selection
    =====================

    fl, fu, ft, gl, gu, and gt are the function and gradient values at the interval end points al and au, and at the trial point at.
    ac is the minimiser of the cubic that interpolates fl, ft, gl, and gt.
    aq is the minimiser of the quadratic that interpolates fl, ft, and gl.
    as is the minimiser of the quadratic that interpolates fl, gl, and gt.

    The trial value selection is divided into four cases.

    Case 1: ft > fl.  In this case compute ac and aq, and set::

               / ac,            if |ac - al| < |aq - al|,
        at+ = <
               \ 1/2(aq + ac),  otherwise.


    Case 2: ft <= fl and gt.gl < 0.  In this case compute ac and as, and set::

               / ac,            if |ac - at| >= |as - at|,
        at+ = <
               \ as,            otherwise.


    Case 3: ft <= fl and gt.gl >= 0, and |gt| <= |gl|.  In this case at+ is chosen by extrapolating the function values at al and at, so the trial value at+ lies outside the interval with at and al as endpoints.  Compute ac and as.

        - If the cubic tends to infinity in the direction of the step and the minimum of the cubic is beyond at, set::

                   / ac,            if |ac - at| < |as - at|,
            at+ = <
                   \ as,            otherwise.

        - Otherwise set at+ = as.


        - Redefine at+ by setting::

                   / min{at + d(au - at), at+},        if at > al.
            at+ = <
                   \ max{at + d(au - at), at+},        otherwise,

        - for some d < 1.


    Case 4: ft <= fl and gt.gl >= 0, and |gt| > |gl|.  In this case choose at+ as the minimiser of
    the cubic that interpolates fu, ft, gu, and gt.


    Interval updating
    =================

    Given a trial value at in I, the endpoints al+ and au+ of the updated interval I+ are determined
    as follows:

        - Case U1: If f(at) > f(al), then al+ = al and au+ = at.
        - Case U2: If f(at) <= f(al) and f'(at)(al - at) > 0, then al+ = at and au+ = au.
        - Case U3: If f(at) <= f(al) and f'(at)(al - at) < 0, then al+ = at and au+ = al.


    @param a:               The trial point data structure.  Its 'phi' and 'phi_prime' values are stored in the updated interval.
    @param Ik:              The current interval data structure.  It is deep copied, not modified.
    @param at:              The trial step.
    @param al:              The lower end point of the interval.
    @param au:              The upper end point of the interval.
    @param ft:              The function value at at (the caller may supply the modified function psi instead of phi).
    @param fl:              The function value at al.
    @param fu:              The function value at au.
    @param gt:              The gradient value at at.
    @param gl:              The gradient value at al.
    @param gu:              The gradient value at au.
    @param bracketed:       The flag which, if true, indicates that a minimiser has been bracketed.
    @param Ik_lim:          The list of the minimum and maximum allowed values for the new trial step.
    @keyword d:             The factor d < 1 of case 3.
    @keyword print_flag:    If true, verbose debugging output is printed.
    @return:                The tuple of the new trial step at+, the updated interval data structure, and the updated bracketed flag.
    """

    # Trial value selection.

    # Case 1:  A higher function value, hence the minimum is bracketed.
    if ft > fl:
        if print_flag:
            print("\tat selection, case 1.")

        # The minimum is bracketed.
        bracketed = 1

        # Interpolation.
        ac = cubic(al, at, fl, ft, gl, gt)
        aq = quadratic(al, at, fl, ft, gl)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\taq: " + repr(aq))

        # Take the cubic step if it is closer to al, otherwise the average of the two steps.
        if abs(ac - al) < abs(aq - al):
            if print_flag:
                print("\t\tabs(ac - al) < abs(aq - al), " + repr(abs(ac - al)) + " < " + repr(abs(aq - al)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - al) >= abs(aq - al), " + repr(abs(ac - al)) + " >= " + repr(abs(aq - al)))
                print("\t\tat_new = 1/2(aq + ac) = " + repr(0.5*(aq + ac)))
            at_new = 0.5*(aq + ac)

    # Case 2:  A lower function value and gradients of opposite sign, hence the minimum is bracketed.
    elif gt * gl < 0.0:
        if print_flag:
            print("\tat selection, case 2.")

        # The minimum is bracketed.
        bracketed = 1

        # Interpolation.
        ac = cubic(al, at, fl, ft, gl, gt)
        asec = secant(al, at, gl, gt)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tasec: " + repr(asec))

        # Take the step furthest from at.
        if abs(ac - at) >= abs(asec - at):
            if print_flag:
                print("\t\tabs(ac - at) >= abs(asec - at), " + repr(abs(ac - at)) + " >= " + repr(abs(asec - at)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - at) < abs(asec - at), " + repr(abs(ac - at)) + " < " + repr(abs(asec - at)))
                print("\t\tat_new = asec = " + repr(asec))
            at_new = asec

    # Case 3:  A lower function value, gradients of the same sign, and a decreasing gradient magnitude.
    elif abs(gt) <= abs(gl):
        if print_flag:
            print("\tat selection, case 3.")

        # Extrapolation.
        ac, beta1, beta2 = cubic_ext(al, at, fl, ft, gl, gt, full_output=1)

        if ac > at and beta2 != 0.0:
            # Keep the cubic step.
            if print_flag:
                print("\t\tac > at and beta2 != 0.0")
        elif at > al:
            # Otherwise replace the cubic step by the upper limit.
            if print_flag:
                print("\t\tat > al, " + repr(at) + " > " + repr(al))
            ac = Ik_lim[1]
        else:
            # Otherwise replace the cubic step by the lower limit.
            ac = Ik_lim[0]

        asec = secant(al, at, gl, gt)

        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tasec: " + repr(asec))

        if bracketed:
            # Take the step closest to at.
            if print_flag:
                print("\t\tBracketed")
            if abs(ac - at) < abs(asec - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) < abs(asec - at), " + repr(abs(ac - at)) + " < " + repr(abs(asec - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) >= abs(asec - at), " + repr(abs(ac - at)) + " >= " + repr(abs(asec - at)))
                    print("\t\t\tat_new = asec = " + repr(asec))
                at_new = asec

            # Redefine at+ so that it stays within the fraction d of the distance from at to au.
            if print_flag:
                print("\t\tRedefining at+")
            if at > al:
                at_new = min(at + d*(au - at), at_new)
                if print_flag:
                    print("\t\t\tat > al, " + repr(at) + " > " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
            else:
                at_new = max(at + d*(au - at), at_new)
                if print_flag:
                    print("\t\t\tat <= al, " + repr(at) + " <= " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
        else:
            # Take the step furthest from at.
            if print_flag:
                print("\t\tNot bracketed")
            if abs(ac - at) > abs(asec - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) > abs(asec - at), " + repr(abs(ac - at)) + " > " + repr(abs(asec - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) <= abs(asec - at), " + repr(abs(ac - at)) + " <= " + repr(abs(asec - at)))
                    print("\t\t\tat_new = asec = " + repr(asec))
                at_new = asec

            # Clamp the step to the limits.
            if print_flag:
                print("\t\tChecking limits.")
            if at_new < Ik_lim[0]:
                if print_flag:
                    print("\t\t\tat_new < Ik_lim[0], " + repr(at_new) + " < " + repr(Ik_lim[0]))
                    print("\t\t\tat_new = " + repr(Ik_lim[0]))
                at_new = Ik_lim[0]
            if at_new > Ik_lim[1]:
                if print_flag:
                    print("\t\t\tat_new > Ik_lim[1], " + repr(at_new) + " > " + repr(Ik_lim[1]))
                    print("\t\t\tat_new = " + repr(Ik_lim[1]))
                at_new = Ik_lim[1]

    # Case 4:  A lower function value, gradients of the same sign, and a non-decreasing gradient magnitude.
    else:
        if print_flag:
            print("\tat selection, case 4.")
        if bracketed:
            if print_flag:
                print("\t\tbracketed.")
            at_new = cubic(au, at, fu, ft, gu, gt)
            if print_flag:
                print("\t\tat_new = " + repr(at_new))
        elif at > al:
            if print_flag:
                print("\t\tnot bracketed but at > al, " + repr(at) + " > " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[1]))
            at_new = Ik_lim[1]
        else:
            if print_flag:
                print("\t\tnot bracketed but at <= al, " + repr(at) + " <= " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[0]))
            at_new = Ik_lim[0]

    # Interval updating, working on a copy so that the input interval is left untouched.
    Ik_new = deepcopy(Ik)

    if ft > fl:
        # Case U1:  The trial point becomes the upper end point.
        if print_flag:
            print("\tIk update, case a, ft > fl.")
        Ik_new['a'][1] = at
        Ik_new['phi'][1] = a['phi']
        Ik_new['phi_prime'][1] = a['phi_prime']
    elif gt*(al - at) > 0.0:
        # Case U2:  The trial point becomes the lower end point.
        if print_flag:
            print("\tIk update, case b, gt*(al - at) > 0.0.")
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']
    else:
        # Case U3:  The trial point becomes the lower end point and the old lower end point the upper.
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']
        Ik_new['a'][1] = al
        Ik_new['phi'][1] = Ik['phi'][0]
        Ik_new['phi_prime'][1] = Ik['phi_prime'][0]

        if print_flag:
            print("\tIk update, case c.")
            print("\t\tat:                  " + repr(at))
            print("\t\ta['phi']:            " + repr(a['phi']))
            print("\t\ta['phi_prime']:      " + repr(a['phi_prime']))
            print("\t\tIk_new['a']:         " + repr(Ik_new['a']))
            print("\t\tIk_new['phi']:       " + repr(Ik_new['phi']))
            print("\t\tIk_new['phi_prime']: " + repr(Ik_new['phi_prime']))

    return at_new, Ik_new, bracketed
 553