1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 """A line search algorithm from More and Thuente.
25
26 This file is part of the U{minfx optimisation library<https://sourceforge.net/projects/minfx>}.
27 """
28
29
30 from copy import deepcopy
31 from math import sqrt
32 from numpy import dot
33 import sys
34
35
36 from minfx.line_search.interpolate import cubic_int, cubic_ext, quadratic_fafbga, quadratic_gagb
37
38
# Short local aliases for the interpolation routines, as used by the trial value selection in update():
# cubic  - cubic_int, the cubic interpolation minimiser,
# quadratic - quadratic_fafbga, the quadratic through f(a), f(b) and g(a),
# secant - quadratic_gagb, the quadratic through g(a) and g(b).
cubic, quadratic, secant = cubic_int, quadratic_fafbga, quadratic_gagb
42
43
def _line_values(func, func_prime, args, x, p, alpha):
    """Evaluate phi(alpha) and phi'(alpha) along the search direction p.

    @param func:        The function, called as func(x + alpha*p, *args).
    @type func:         function
    @param func_prime:  The gradient function, called as func_prime(x + alpha*p, *args).
    @type func_prime:   function
    @param args:        The tuple of extra arguments passed to func and func_prime.
    @type args:         tuple
    @param x:           The starting position.
    @param p:           The search direction.
    @param alpha:       The step length.
    @type alpha:        float
    @return:            The tuple (phi, phi_prime) where phi = func(x + alpha*p, *args) and phi_prime = dot(func_prime(x + alpha*p, *args), p).
    @rtype:             tuple
    """

    # The trial point is rebuilt for each call so that func and func_prime each receive their own object.
    phi = func(*(x + alpha*p,)+args)
    phi_prime = dot(func_prime(*(x + alpha*p,)+args), p)
    return phi, phi_prime


def more_thuente(func, func_prime, args, x, f, g, p, a_init=1.0, a_min=1e-25, a_max=None, a_tol=1e-10, phi_min=-1e3, mu=0.001, eta=0.1, print_flag=0):
    """A line search algorithm from More and Thuente.

    More, J. J., and Thuente, D. J. 1994, Line search algorithms with guaranteed sufficient decrease. ACM Trans. Math. Softw. 20, 286-307.


    Internal variables
    ==================

    a0, the null sequence data structure containing the following keys:

        - 'a'           - 0
        - 'phi'         - phi(0)
        - 'phi_prime'   - phi'(0)

    a, the sequence data structure containing the following keys:

        - 'a'           - alpha
        - 'phi'         - phi(alpha)
        - 'phi_prime'   - phi'(alpha)

    Ik, the interval data structure containing the following keys:

        - 'a'           - The current interval Ik = [al, au]
        - 'phi'         - The interval [phi(al), phi(au)]
        - 'phi_prime'   - The interval [phi'(al), phi'(au)]

    Instead of using the modified function::

        psi(a) = phi(a) - phi(0) - a.phi'(0),

    the function::

        psi(a) = phi(a) - a.phi'(0),

    was used as the phi(0) component has no effect on the results.


    Parameters and return values
    ============================

    @param func:        The function to minimise, called as func(x + a*p, *args) to give phi(a).
    @type func:         function
    @param func_prime:  The gradient function, called as func_prime(x + a*p, *args); phi'(a) is its dot product with p.
    @type func_prime:   function
    @param args:        The tuple of extra arguments passed to func and func_prime.
    @type args:         tuple
    @param x:           The starting position.
    @param f:           The function value at x, i.e. phi(0).
    @type f:            float
    @param g:           The gradient at x.  dot(g, p) gives phi'(0).
    @param p:           The search direction.  This must be a descent direction.
    @param a_init:      The initial step length.
    @type a_init:       float
    @param a_min:       The minimum step length.  A false value (e.g. None or 0) is replaced by 0.0.
    @type a_min:        float
    @param a_max:       The maximum step length.  A false value (e.g. None) is replaced by 4*max(1, a_init).
    @type a_max:        float
    @param a_tol:       The relative tolerance on the width of the bracketing interval.
    @type a_tol:        float
    @param phi_min:     Not used by this implementation.
    @type phi_min:      float
    @param mu:          The sufficient decrease constant of the strong Wolfe conditions.
    @type mu:           float
    @param eta:         The curvature constant of the strong Wolfe conditions.
    @type eta:          float
    @param print_flag:  If true, print debugging output.
    @type print_flag:   int
    @return:            The tuple (a, f_count, g_count) of the chosen step length and the number of function and gradient evaluations.
    @rtype:             tuple
    @raise NameError:   If phi'(0) > 0 (p is not a descent direction) or if a_init lies outside [a_min, a_max].
    """

    # Initialise the counters and flags.
    k = 0
    f_count = 0
    g_count = 0
    mod_flag = 1        # 1 - use the modified function psi, 0 - use phi.
    bracketed = 0

    # The data at a = 0.
    a0 = {}
    a0['a'] = 0.0
    a0['phi'] = f
    a0['phi_prime'] = dot(g, p)

    # Default step length bounds.
    if not a_min:
        a_min = 0.0
    if not a_max:
        a_max = 4.0*max(1.0, a_init)

    # The initial trial step limits and the interval widths used to decide when to bisect.
    Ik_lim = [0.0, 5.0*a_init]
    width = a_max - a_min
    width2 = 2.0*width

    # The data at the initial trial step.
    a = {}
    a['a'] = a_init
    a['phi'], a['phi_prime'] = _line_values(func, func_prime, args, x, p, a['a'])
    f_count = f_count + 1
    g_count = g_count + 1

    # The interval Ik = [al, au], initially collapsed onto a = 0.
    Ik = {}
    Ik['a'] = [0.0, 0.0]
    Ik['phi'] = [a0['phi'], a0['phi']]
    Ik['phi_prime'] = [a0['phi_prime'], a0['phi_prime']]

    if print_flag:
        print("\n<Line search initial values>")
        print_data("Pre", -1, a0, Ik, Ik_lim)

    # Test for errors.
    if a0['phi_prime'] > 0.0:
        if print_flag:
            print("xk = " + repr(x))
            print("fk = " + repr(f))
            print("dfk = " + repr(g))
            print("pk = " + repr(p))
            print("dot(dfk, pk) = " + repr(dot(g, p)))
            print("a0['phi_prime'] = " + repr(a0['phi_prime']))
        raise NameError("The gradient at point 0 of this line search is positive, ie p is not a descent direction and the line search will not work.")
    if a['a'] < a_min:
        raise NameError("Alpha is less than alpha_min, " + repr(a['a']) + " < " + repr(a_min))
    if a['a'] > a_max:
        raise NameError("Alpha is greater than alpha_max, " + repr(a['a']) + " > " + repr(a_max))

    while True:
        if print_flag:
            print("\n<Line search iteration k = " + repr(k+1) + " >")
            print("Bracketed: " + repr(bracketed))
            print_data("Initial", k, a, Ik, Ik_lim)

        # The curvature term mu.phi'(0) and the sufficient decrease bound phi(0) + a.mu.phi'(0).
        curv = mu * a0['phi_prime']
        suff_dec = a0['phi'] + a['a'] * curv

        # Switch permanently from psi to phi once sufficient decrease holds with a non-negative slope.
        if mod_flag:
            if a['phi'] <= suff_dec and a['phi_prime'] >= 0.0:
                mod_flag = 0

        # Test for convergence using the strong Wolfe conditions.
        if print_flag:
            print("Testing for convergence using the strong Wolfe conditions.")
        if a['phi'] <= suff_dec and abs(a['phi_prime']) <= eta * abs(a0['phi_prime']):
            if print_flag:
                print("\tYes.")
                print("<Line search has converged>\n")
            return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        # Test if the step length limits have been reached.
        if print_flag:
            print("Testing if limits have been reached.")
        if a['a'] == a_min:
            if a['phi'] > suff_dec or a['phi_prime'] >= curv:
                if print_flag:
                    print("\tYes.")
                    print("<Min alpha has been reached>\n")
                return a['a'], f_count, g_count
        if a['a'] == a_max:
            if a['phi'] <= suff_dec and a['phi_prime'] <= curv:
                if print_flag:
                    print("\tYes.")
                    print("<Max alpha has been reached>\n")
                return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        if bracketed:
            # Stop if the step has landed on or outside the bracket (rounding errors prevent progress).
            if print_flag:
                print("Testing for roundoff error.")
            if a['a'] <= Ik_lim[0] or a['a'] >= Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping due to roundoff error>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

            # Stop if the bracket is narrower than the relative tolerance a_tol.
            if print_flag:
                print("Testing tol.")
            if Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping tol>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

        # Choose a safeguarded trial step and update the interval Ik.
        a_new = {}
        if mod_flag and a['phi'] <= Ik['phi'][0] and a['phi'] > suff_dec:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the modified function psi.")

            # The modified function values and gradients at at, al and au.
            psi = a['phi'] - curv * a['a']
            psi_l = Ik['phi'][0] - curv * Ik['a'][0]
            psi_u = Ik['phi'][1] - curv * Ik['a'][1]
            psi_prime = a['phi_prime'] - curv
            psi_l_prime = Ik['phi_prime'][0] - curv
            psi_u_prime = Ik['phi_prime'][1] - curv

            a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], psi, psi_l, psi_u, psi_prime, psi_l_prime, psi_u_prime, bracketed, Ik_lim, print_flag=print_flag)
        else:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the function phi.")
            a_new['a'], Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], a['phi'], Ik['phi'][0], Ik['phi'][1], a['phi_prime'], Ik['phi_prime'][0], Ik['phi_prime'][1], bracketed, Ik_lim, print_flag=print_flag)

        # Bisect if the bracket has not shrunk enough over the last two iterations.  As in the MINPACK-2
        # dcsrch reference, the width history is updated on every bracketed iteration.
        if bracketed:
            size = abs(Ik_new['a'][0] - Ik_new['a'][1])
            if size >= 0.66 * width2:
                if print_flag:
                    print("Bisection step.")
                a_new['a'] = 0.5 * (Ik_new['a'][0] + Ik_new['a'][1])
            width2 = width
            width = size

        # The limits for the next trial step: the bracket itself, or an extrapolation range beyond al.
        if print_flag:
            print("Limiting")
            print(" Ik_lim: " + repr(Ik_lim))
        if bracketed:
            if print_flag:
                print(" Bracketed.")
            Ik_lim[0] = min(Ik_new['a'][0], Ik_new['a'][1])
            Ik_lim[1] = max(Ik_new['a'][0], Ik_new['a'][1])
        else:
            if print_flag:
                print(" Not bracketed.")
                print(" a_new['a']: " + repr(a_new['a']))
                print(" xtrapl: " + repr(1.1))
            Ik_lim[0] = a_new['a'] + 1.1 * (a_new['a'] - Ik_new['a'][0])
            Ik_lim[1] = a_new['a'] + 4.0 * (a_new['a'] - Ik_new['a'][0])
        if print_flag:
            print(" Ik_lim: " + repr(Ik_lim))

        # If no further progress is possible, fall back to the best step obtained so far.  That is the
        # lower end point of the UPDATED interval (MINPACK-2's stp = stx), not the pre-update al.
        if bracketed:
            if a_new['a'] <= Ik_lim[0] or a_new['a'] >= Ik_lim[1] or Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("aaa")
                a_new['a'] = Ik_new['a'][0]

        # Force the step to lie within [a_min, a_max].
        if a_new['a'] < a_min:
            if print_flag:
                print("The step is below a_min, therefore setting the step length to a_min.")
            a_new['a'] = a_min
        if a_new['a'] > a_max:
            if print_flag:
                print("The step is above a_max, therefore setting the step length to a_max.")
            a_new['a'] = a_max

        # Evaluate the function and gradient at the new trial step.
        if print_flag:
            print("Calculating new values.")
        a_new['phi'], a_new['phi_prime'] = _line_values(func, func_prime, args, x, p, a_new['a'])
        f_count = f_count + 1
        g_count = g_count + 1

        # Move on to the next iteration.
        k = k + 1
        if print_flag:
            print("Bracketed: " + repr(bracketed))
            print_data("Final", k, a_new, Ik_new, Ik_lim)
        a = deepcopy(a_new)
        Ik = deepcopy(Ik_new)
280
def print_data(text, k, a, Ik, Ik_lim):
    """Temp func for debugging.

    Print the trial step data, the interval data and the step limits of one line search iteration.

    @param text:    A label printed at the start of the printout.
    @type text:     str
    @param k:       The iteration counter; k+1 is printed as the iteration number.
    @type k:        int
    @param a:       The step data, with the keys 'a', 'phi' and 'phi_prime'.
    @type a:        dict
    @param Ik:      The interval data, with the keys 'a', 'phi' and 'phi_prime'.
    @type Ik:       dict
    @param Ik_lim:  The current step limits.
    @type Ik_lim:   list
    """

    print(text + " data printout:")
    print(" Iteration: " + repr(k+1))
    print(" a: " + repr(a['a']))
    print(" phi: " + repr(a['phi']))
    print(" phi_prime: " + repr(a['phi_prime']))
    print(" Ik: " + repr(Ik['a']))
    print(" phi_I: " + repr(Ik['phi']))
    print(" phi_I_prime: " + repr(Ik['phi_prime']))
    print(" Ik_lim: " + repr(Ik_lim))
294
295
def update(a, Ik, at, al, au, ft, fl, fu, gt, gl, gu, bracketed, Ik_lim, d=0.66, print_flag=0):
    r"""Trial value selection and interval updating.

    Trial value selection
    =====================

    fl, fu, ft, gl, gu, and gt are the function and gradient values at the interval end points al and au, and at the trial point at.
    ac is the minimiser of the cubic that interpolates fl, ft, gl, and gt.
    aq is the minimiser of the quadratic that interpolates fl, ft, and gl.
    as is the minimiser of the quadratic that interpolates fl, gl, and gt.

    The trial value selection is divided into four cases.

    Case 1: ft > fl.  In this case compute ac and aq, and set::

                  / ac,             if |ac - al| < |aq - al|,
            at+ = <
                  \ 1/2(aq + ac),   otherwise.


    Case 2: ft <= fl and gt.gl < 0.  In this case compute ac and as, and set::

                  / ac,     if |ac - at| >= |as - at|,
            at+ = <
                  \ as,     otherwise.


    Case 3: ft <= fl and gt.gl >= 0, and |gt| <= |gl|.  In this case at+ is chosen by extrapolating the function values at al and at, so the trial value at+ lies outside the interval with at and al as endpoints.  Compute ac and as.

        - If the cubic tends to infinity in the direction of the step and the minimum of the cubic is beyond at, set::

                      / ac,     if |ac - at| < |as - at|,
                at+ = <
                      \ as,     otherwise.

        - Otherwise set at+ = as.


        - Redefine at+ by setting::

                      / min{at + d(au - at), at+},  if at > al.
                at+ = <
                      \ max{at + d(au - at), at+},  otherwise,

        - for some d < 1.


    Case 4: ft <= fl and gt.gl >= 0, and |gt| > |gl|.  In this case choose at+ as the minimiser of
    the cubic that interpolates fu, ft, gu, and gt.


    Interval updating
    =================

    Given a trial value at in I, the endpoints al+ and au+ of the updated interval I+ are determined
    as follows:

        - Case U1:  If f(at) > f(al), then al+ = al and au+ = at.
        - Case U2:  If f(at) <= f(al) and f'(at)(al - at) > 0, then al+ = at and au+ = au.
        - Case U3:  If f(at) <= f(al) and f'(at)(al - at) < 0, then al+ = at and au+ = al.


    Parameters and return values
    ============================

    The function and gradient arguments (ft, fl, fu, gt, gl, gu) may be values of either phi or the
    modified function psi.  The updated interval, however, always stores the phi values read from a
    and Ik.

    @param a:           The trial point data.  Only a['phi'] and a['phi_prime'] are read, and are stored in the updated interval.
    @type a:            dict
    @param Ik:          The interval data with the keys 'a', 'phi' and 'phi_prime'.  It is deep-copied and not modified.
    @type Ik:           dict
    @param at:          The trial step.
    @param al:          The lower end point of the interval (the best step so far).
    @param au:          The upper end point of the interval.
    @param ft:          The function value at at.
    @param fl:          The function value at al.
    @param fu:          The function value at au.
    @param gt:          The gradient at at.
    @param gl:          The gradient at al.
    @param gu:          The gradient at au.
    @param bracketed:   The flag indicating if a minimum has been bracketed.
    @type bracketed:    int
    @param Ik_lim:      The [lower, upper] step limits used when extrapolating in cases 3 and 4.
    @type Ik_lim:       list
    @param d:           The factor d < 1 of the case 3 redefinition of at+.
    @type d:            float
    @param print_flag:  If true, print debugging output.
    @type print_flag:   int
    @return:            The tuple (at_new, Ik_new, bracketed) of the new trial step, the updated interval data, and the updated bracketed flag.
    @rtype:             tuple
    """

    # Trial value selection.
    ########################

    # Case 1: a higher function value.
    if ft > fl:
        if print_flag:
            print("\tat selection, case 1.")

        # The minimum has been bracketed.
        bracketed = 1

        # Cubic and quadratic interpolation between al and at.
        ac = cubic(al, at, fl, ft, gl, gt)
        aq = quadratic(al, at, fl, ft, gl)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\taq: " + repr(aq))

        # Take the cubic step if it is closer to al, otherwise the average of the cubic and quadratic steps.
        if abs(ac - al) < abs(aq - al):
            if print_flag:
                print("\t\tabs(ac - al) < abs(aq - al), " + repr(abs(ac - al)) + " < " + repr(abs(aq - al)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - al) >= abs(aq - al), " + repr(abs(ac - al)) + " >= " + repr(abs(aq - al)))
                print("\t\tat_new = 1/2(aq + ac) = " + repr(0.5*(aq + ac)))
            at_new = 0.5*(aq + ac)

    # Case 2: a lower function value and gradients of opposite sign.
    elif gt * gl < 0.0:
        if print_flag:
            print("\tat selection, case 2.")

        # The minimum has been bracketed.
        bracketed = 1

        # Cubic and secant interpolation between al and at.
        ac = cubic(al, at, fl, ft, gl, gt)
        asec = secant(al, at, gl, gt)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tasec: " + repr(asec))

        # Take the cubic step if it is at least as far from at as the secant step, otherwise the secant step.
        if abs(ac - at) >= abs(asec - at):
            if print_flag:
                print("\t\tabs(ac - at) >= abs(asec - at), " + repr(abs(ac - at)) + " >= " + repr(abs(asec - at)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - at) < abs(asec - at), " + repr(abs(ac - at)) + " < " + repr(abs(asec - at)))
                print("\t\tat_new = asec = " + repr(asec))
            at_new = asec

    # Case 3: a lower function value, gradients of the same sign, and a non-increasing gradient magnitude.
    elif abs(gt) <= abs(gl):
        if print_flag:
            print("\tat selection, case 3.")

        # Cubic extrapolation.  beta1 is not used here.
        ac, beta1, beta2 = cubic_ext(al, at, fl, ft, gl, gt, full_output=1)

        # NOTE(review): 'ac > at' expresses "the cubic minimum lies beyond at" only for steps in the
        # positive direction (at > al).  Confirm against cubic_ext what is intended when at < al.
        if ac > at and beta2 != 0.0:
            # Keep the extrapolated cubic step.
            if print_flag:
                print("\t\tac > at and beta2 != 0.0")
        elif at > al:
            # Otherwise use the upper step limit...
            if print_flag:
                print("\t\tat > al, " + repr(at) + " > " + repr(al))
            ac = Ik_lim[1]
        else:
            # ...or the lower step limit, depending on the direction of the step.
            ac = Ik_lim[0]

        asec = secant(al, at, gl, gt)

        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tasec: " + repr(asec))

        if bracketed:
            # Take whichever of the cubic and secant steps is closer to at.
            if print_flag:
                print("\t\tBracketed")
            if abs(ac - at) < abs(asec - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) < abs(asec - at), " + repr(abs(ac - at)) + " < " + repr(abs(asec - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) >= abs(asec - at), " + repr(abs(ac - at)) + " >= " + repr(abs(asec - at)))
                    print("\t\t\tat_new = asec = " + repr(asec))
                at_new = asec

            # Redefine at+ so that the step moves at most a fraction d of the way from at to au.
            if print_flag:
                print("\t\tRedefining at+")
            if at > al:
                at_new = min(at + d*(au - at), at_new)
                if print_flag:
                    print("\t\t\tat > al, " + repr(at) + " > " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
            else:
                at_new = max(at + d*(au - at), at_new)
                if print_flag:
                    print("\t\t\tat <= al, " + repr(at) + " <= " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
        else:
            # Take whichever of the cubic and secant steps is farther from at.
            if print_flag:
                print("\t\tNot bracketed")
            if abs(ac - at) > abs(asec - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) > abs(asec - at), " + repr(abs(ac - at)) + " > " + repr(abs(asec - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) <= abs(asec - at), " + repr(abs(ac - at)) + " <= " + repr(abs(asec - at)))
                    print("\t\t\tat_new = asec = " + repr(asec))
                at_new = asec

            # Clamp the extrapolated step to the step limits.
            if print_flag:
                print("\t\tChecking limits.")
            if at_new < Ik_lim[0]:
                if print_flag:
                    print("\t\t\tat_new < Ik_lim[0], " + repr(at_new) + " < " + repr(Ik_lim[0]))
                    print("\t\t\tat_new = " + repr(Ik_lim[0]))
                at_new = Ik_lim[0]
            if at_new > Ik_lim[1]:
                if print_flag:
                    print("\t\t\tat_new > Ik_lim[1], " + repr(at_new) + " > " + repr(Ik_lim[1]))
                    print("\t\t\tat_new = " + repr(Ik_lim[1]))
                at_new = Ik_lim[1]

    # Case 4: a lower function value, gradients of the same sign, and an increasing gradient magnitude.
    else:
        if print_flag:
            print("\tat selection, case 4.")
        if bracketed:
            # Cubic interpolation between au and at.
            if print_flag:
                print("\t\tbracketed.")
            at_new = cubic(au, at, fu, ft, gu, gt)
            if print_flag:
                print("\t\tat_new = " + repr(at_new))
        elif at > al:
            # Not bracketed: jump to the step limit in the direction of the step.
            if print_flag:
                print("\t\tnot bracketed but at > al, " + repr(at) + " > " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[1]))
            at_new = Ik_lim[1]
        else:
            if print_flag:
                print("\t\tnot bracketed but at <= al, " + repr(at) + " <= " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[0]))
            at_new = Ik_lim[0]

    # Interval updating.
    ####################

    # Work on a copy so that the caller's Ik is left untouched.
    Ik_new = deepcopy(Ik)

    # Case U1: al+ = al and au+ = at.
    if ft > fl:
        if print_flag:
            print("\tIk update, case a, ft > fl.")
        Ik_new['a'][1] = at
        Ik_new['phi'][1] = a['phi']
        Ik_new['phi_prime'][1] = a['phi_prime']

    # Case U2: al+ = at and au+ = au.
    elif gt*(al - at) > 0.0:
        if print_flag:
            print("\tIk update, case b, gt*(al - at) > 0.0.")
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']

    # Case U3: al+ = at and au+ = al (this branch is also taken when gt*(al - at) == 0).
    else:
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']
        Ik_new['a'][1] = al
        Ik_new['phi'][1] = Ik['phi'][0]
        Ik_new['phi_prime'][1] = Ik['phi_prime'][0]

        if print_flag:
            print("\tIk update, case c.")
            print("\t\tat: " + repr(at))
            print("\t\ta['phi']: " + repr(a['phi']))
            print("\t\ta['phi_prime']: " + repr(a['phi_prime']))
            print("\t\tIk_new['a']: " + repr(Ik_new['a']))
            print("\t\tIk_new['phi']: " + repr(Ik_new['phi']))
            print("\t\tIk_new['phi_prime']: " + repr(Ik_new['phi_prime']))

    return at_new, Ik_new, bracketed
553