1 import sys
2 from copy import deepcopy
3 from math import sqrt
4 from Numeric import copy, dot
5
6 from interpolate import cubic_int, cubic_ext, quadratic_fafbga, quadratic_gagb
7
# Short aliases for the interpolation routines used by update():
#   cubic     - minimiser of the cubic interpolating two function values and two derivatives,
#               called as cubic(a1, a2, f1, f2, g1, g2).
#   quadratic - minimiser of the quadratic interpolating two function values and one derivative,
#               called as quadratic(a1, a2, f1, f2, g1).
#   secant    - minimiser of the quadratic interpolating two derivatives,
#               called as secant(a1, a2, g1, g2).
cubic, quadratic, secant = cubic_int, quadratic_fafbga, quadratic_gagb
11
12
def more_thuente(func, func_prime, args, x, f, g, p, a_init=1.0, a_min=None, a_max=None, a_tol=1e-10, phi_min=-1e3, mu=0.001, eta=0.1, print_flag=0):
    """A line search algorithm from More and Thuente.

    More, J. J., and Thuente, D. J. 1994, Line search algorithms with guaranteed sufficient decrease.
    ACM Trans. Math. Softw. 20, 286-307.

    The step length alpha is searched for along the line x + alpha*p, where

        phi(alpha)  = func(x + alpha*p, *args)
        phi'(alpha) = dot(func_prime(x + alpha*p, *args), p)

    until the strong Wolfe conditions

        phi(alpha) <= phi(0) + mu.alpha.phi'(0)
        |phi'(alpha)| <= eta.|phi'(0)|

    hold, or until one of the safeguarding tests (step limits reached, rounding errors, interval
    width below a_tol) stops the search.


    Arguments
    ~~~~~~~~~

    func        - The function to minimise, called as func(x, *args).
    func_prime  - The gradient of func, called as func_prime(x, *args).
    args        - A tuple of extra arguments passed to func and func_prime.
    x           - The starting point of the line search.
    f           - The function value at x, used as phi(0).
    g           - The gradient at x.
    p           - The search direction.  phi'(0) = dot(g, p) must not be positive.
    a_init      - The initial step length.
    a_min       - The smallest allowed step length (0.0 when None or zero).
    a_max       - The largest allowed step length (4.0*max(1.0, a_init) when None or zero).
    a_tol       - The relative tolerance on the width of the bracketing interval.
    phi_min     - Unused, accepted for backwards compatibility only.
    mu          - The sufficient decrease constant.
    eta         - The curvature condition constant.
    print_flag  - If true, print debugging information.


    Returns
    ~~~~~~~

    The tuple (alpha, f_count, g_count) of the final step length and the number of function and
    gradient evaluations performed.


    Raises
    ~~~~~~

    NameError if phi'(0) > 0 (p is not a descent direction), or if a_init lies outside
    [a_min, a_max].


    Internal variables
    ~~~~~~~~~~~~~~~~~~

    a0, the null sequence data structure containing the following keys:
        'a'         - 0
        'phi'       - phi(0)
        'phi_prime' - phi'(0)

    a, the sequence data structure containing the following keys:
        'a'         - alpha
        'phi'       - phi(alpha)
        'phi_prime' - phi'(alpha)

    Ik, the interval data structure containing the following keys:
        'a'         - The current interval Ik = [al, au]
        'phi'       - The interval [phi(al), phi(au)]
        'phi_prime' - The interval [phi'(al), phi'(au)]

    Instead of using the modified function:
        psi(a) = phi(a) - phi(0) - mu.a.phi'(0),
    the function:
        psi(a) = phi(a) - mu.a.phi'(0),
    is used, as the constant phi(0) component has no effect on the results.
    """

    k = 0
    f_count = 0
    g_count = 0

    # 1 while the modified function psi may still be used, 0 once the switch to phi has been made.
    mod_flag = 1
    bracketed = 0

    # The values at alpha = 0 come from the caller rather than from a new evaluation.
    a0 = {}
    a0['a'] = 0.0
    a0['phi'] = f
    a0['phi_prime'] = dot(g, p)

    # A missing (or zero) bound selects the default.
    if not a_min:
        a_min = 0.0
    if not a_max:
        a_max = 4.0*max(1.0, a_init)

    # The limits for the next trial step, initially [0, a_init + 4*a_init].
    Ik_lim = [0.0, 5.0*a_init]

    # Interval widths used to force a bisection step when the bracket does not shrink fast enough.
    width = a_max - a_min
    width2 = 2.0*width

    # The initial trial point.
    a = _eval_phi(func, func_prime, args, x, p, a_init)
    f_count = f_count + 1
    g_count = g_count + 1

    # The initial interval is the degenerate interval [0, 0].
    Ik = {}
    Ik['a'] = [0.0, 0.0]
    Ik['phi'] = [a0['phi'], a0['phi']]
    Ik['phi_prime'] = [a0['phi_prime'], a0['phi_prime']]

    if print_flag:
        print("\n<Line search initial values>")
        print_data("Pre", -1, a0, Ik, Ik_lim)

    # Input checking.
    if a0['phi_prime'] > 0.0:
        # These diagnostics are printed whether or not print_flag is set.
        print("self.xk = " + repr(x))
        print("self.fk = " + repr(f))
        print("self.dfk = " + repr(g))
        print("self.pk = " + repr(p))
        print("dot(self.dfk, self.pk) = " + repr(dot(g, p)))
        print("a0['phi_prime'] = " + repr(a0['phi_prime']))
        raise NameError("The gradient at point 0 of this line search is positive, ie p is not a descent direction and the line search will not work.")
    if a['a'] < a_min:
        raise NameError("Alpha is less than alpha_min, " + repr(a['a']) + " < " + repr(a_min))
    if a['a'] > a_max:
        raise NameError("Alpha is greater than alpha_max, " + repr(a['a']) + " > " + repr(a_max))

    while True:
        if print_flag:
            print("\n<Line search iteration k = " + repr(k+1) + " >")
            print("Bracketed: " + repr(bracketed))
            print_data("Initial", k, a, Ik, Ik_lim)

        # Sufficient decrease test value: phi(0) + mu.alpha.phi'(0).
        curv = mu * a0['phi_prime']
        suff_dec = a0['phi'] + a['a'] * curv

        # Once sufficient decrease holds with a non-negative slope, stop using psi for good.
        if mod_flag and a['phi'] <= suff_dec and a['phi_prime'] >= 0.0:
            mod_flag = 0

        # Convergence test.
        if print_flag:
            print("Testing for convergence using the strong Wolfe conditions.")
        if a['phi'] <= suff_dec and abs(a['phi_prime']) <= eta * abs(a0['phi_prime']):
            if print_flag:
                print("\tYes.")
                print("<Line search has converged>\n")
            return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        # Step limit tests.
        if print_flag:
            print("Testing if limits have been reached.")
        if a['a'] == a_min and (a['phi'] > suff_dec or a['phi_prime'] >= curv):
            if print_flag:
                print("\tYes.")
                print("<Min alpha has been reached>\n")
            return a['a'], f_count, g_count
        if a['a'] == a_max and a['phi'] <= suff_dec and a['phi_prime'] <= curv:
            if print_flag:
                print("\tYes.")
                print("<Max alpha has been reached>\n")
            return a['a'], f_count, g_count
        if print_flag:
            print("\tNo.")

        # Both tests below only apply to a bracketed minimiser.  Their nesting follows MINPACK-2 dcsrch.
        if bracketed:
            if print_flag:
                print("Testing for roundoff error.")
            if a['a'] <= Ik_lim[0] or a['a'] >= Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping due to roundoff error>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

            if print_flag:
                print("Testing tol.")
            if Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("\tYes.")
                    print("<Stopping tol>\n")
                return a['a'], f_count, g_count
            if print_flag:
                print("\tNo.")

        # Choose the next trial step and update the interval, using psi or phi.
        if mod_flag and a['phi'] <= Ik['phi'][0] and a['phi'] > suff_dec:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the modified function psi.")
            # psi(alpha) = phi(alpha) - alpha.curv and psi'(alpha) = phi'(alpha) - curv.
            ft = a['phi'] - curv * a['a']
            fl = Ik['phi'][0] - curv * Ik['a'][0]
            fu = Ik['phi'][1] - curv * Ik['a'][1]
            gt = a['phi_prime'] - curv
            gl = Ik['phi_prime'][0] - curv
            gu = Ik['phi_prime'][1] - curv
        else:
            if print_flag:
                print("Choosing ak and updating the interval Ik using the function phi.")
            ft = a['phi']
            fl = Ik['phi'][0]
            fu = Ik['phi'][1]
            gt = a['phi_prime']
            gl = Ik['phi_prime'][0]
            gu = Ik['phi_prime'][1]
        a_next, Ik_new, bracketed = update(a, Ik, a['a'], Ik['a'][0], Ik['a'][1], ft, fl, fu, gt, gl, gu, bracketed, Ik_lim, print_flag=print_flag)

        # Bisect if the bracket has not shrunk enough over the last two iterations.
        if bracketed:
            size = abs(Ik_new['a'][0] - Ik_new['a'][1])
            if size >= 0.66 * width2:
                if print_flag:
                    print("Bisection step.")
                a_next = 0.5 * (Ik_new['a'][0] + Ik_new['a'][1])
            width2 = width
            width = size

        # Limits for the next trial step.
        if print_flag:
            print("Limiting")
            print(" Ik_lim: " + repr(Ik_lim))
        if bracketed:
            if print_flag:
                print(" Bracketed.")
            Ik_lim[0] = min(Ik_new['a'][0], Ik_new['a'][1])
            Ik_lim[1] = max(Ik_new['a'][0], Ik_new['a'][1])
        else:
            if print_flag:
                print(" Not bracketed.")
                print(" a_new['a']: " + repr(a_next))
                print(" xtrapl: " + repr(1.1))
            Ik_lim[0] = a_next + 1.1 * (a_next - Ik_new['a'][0])
            Ik_lim[1] = a_next + 4.0 * (a_next - Ik_new['a'][0])
        if print_flag:
            print(" Ik_lim: " + repr(Ik_lim))

        # Force the step into [a_min, a_max].
        if a_next < a_min:
            if print_flag:
                print("The step is below a_min, therefore setting the step length to a_min.")
            a_next = a_min
        if a_next > a_max:
            if print_flag:
                print("The step is above a_max, therefore setting the step length to a_max.")
            a_next = a_max

        # If no further progress is possible, fall back to the best step found so far.
        # That is the lower end of the updated interval (stx after dcstep in MINPACK-2),
        # not of the interval from the previous iteration.
        if bracketed:
            if a_next <= Ik_lim[0] or a_next >= Ik_lim[1] or Ik_lim[1] - Ik_lim[0] <= a_tol * Ik_lim[1]:
                if print_flag:
                    print("aaa")
                a_next = Ik_new['a'][0]

        if print_flag:
            print("Calculating new values.")
        a_new = _eval_phi(func, func_prime, args, x, p, a_next)
        f_count = f_count + 1
        g_count = g_count + 1

        k = k + 1
        if print_flag:
            print("Bracketed: " + repr(bracketed))
            print_data("Final", k, a_new, Ik_new, Ik_lim)
        a = deepcopy(a_new)
        Ik = deepcopy(Ik_new)


def _eval_phi(func, func_prime, args, x, p, alpha):
    """Evaluate the line function and its directional derivative at the step alpha.

    Returns the sequence data structure {'a': alpha, 'phi': func(x + alpha*p, *args),
    'phi_prime': dot(func_prime(x + alpha*p, *args), p)}.
    """

    x_alpha = x + alpha*p
    point = {}
    point['a'] = alpha
    point['phi'] = func(x_alpha, *args)
    point['phi_prime'] = dot(func_prime(x_alpha, *args), p)
    return point
242
243
244
246 "Temp func for debugging."
247
248 print text + " data printout:"
249 print " Iteration: " + `k+1`
250 print " a: " + `a['a']`
251 print " phi: " + `a['phi']`
252 print " phi_prime: " + `a['phi_prime']`
253 print " Ik: " + `Ik['a']`
254 print " phi_I: " + `Ik['phi']`
255 print " phi_I_prime: " + `Ik['phi_prime']`
256 print " Ik_lim: " + `Ik_lim`
257
258
def update(a, Ik, at, al, au, ft, fl, fu, gt, gl, gu, bracketed, Ik_lim, d=0.66, print_flag=0):
    """Trial value selection and interval updating.

    One step of the More and Thuente interval update (the dcstep routine of MINPACK-2).


    Arguments
    ~~~~~~~~~

    a           - The trial point data structure (keys 'a', 'phi', 'phi_prime').  Its 'phi' and
                  'phi_prime' values, not ft and gt, are stored in the updated interval.
    Ik          - The interval data structure (keys 'a', 'phi', 'phi_prime').  It is copied, not
                  modified.
    at          - The trial step.
    al, au      - The interval end points.  al is the end point with the lowest function value
                  seen so far.
    ft, fl, fu  - The function values at at, al and au (phi, or the modified function psi).
    gt, gl, gu  - The derivative values at at, al and au.
    bracketed   - Whether a minimiser has already been bracketed.
    Ik_lim      - The limits [min, max] for the trial step, used when extrapolating.
    d           - The safeguarding fraction (d < 1) used in case 3.
    print_flag  - If true, print debugging information.


    Returns
    ~~~~~~~

    The tuple (at_new, Ik_new, bracketed) of the next trial step, the updated interval data
    structure and the updated bracketing flag.


    Trial value selection
    ~~~~~~~~~~~~~~~~~~~~~

    fl, fu, ft, gl, gu, and gt are the function and gradient values at the interval end points al
    and au, and at the trial point at.
    ac is the minimiser of the cubic that interpolates fl, ft, gl, and gt.
    aq is the minimiser of the quadratic that interpolates fl, ft, and gl.
    as (a_s in the code) is the minimiser of the quadratic that interpolates fl, gl, and gt.

    The trial value selection is divided into four cases.

    Case 1:  ft > fl.  In this case compute ac and aq, and set

                   / ac,            if |ac - al| < |aq - al|,
            at+ = <
                   \ 1/2(aq + ac),  otherwise.


    Case 2:  ft <= fl and gt.gl < 0.  In this case compute ac and as, and set

                   / ac,  if |ac - at| >= |as - at|,
            at+ = <
                   \ as,  otherwise.


    Case 3:  ft <= fl and gt.gl >= 0, and |gt| <= |gl|.  In this case at+ is chosen by
    extrapolating the function values at al and at, so the trial value at+ lies outside the
    interval with at and al as endpoints.  Compute as, and take ac as the cubic minimiser if the
    cubic tends to infinity in the direction of the step and its minimum lies beyond at.
    Otherwise ac is the step limit in the direction of the step (Ik_lim[1] if at > al, else
    Ik_lim[0]).

    If bracketed, set

                   / ac,  if |ac - at| < |as - at|,
            at+ = <
                   \ as,  otherwise,

    and then redefine at+ by setting

                   / min{at + d(au - at), at+},  if at > al,
            at+ = <
                   \ max{at + d(au - at), at+},  otherwise,

    for some d < 1.

    If not bracketed, set at+ to whichever of ac and as is farther from at (as on ties), then
    clip it to Ik_lim.


    Case 4:  ft <= fl and gt.gl >= 0, and |gt| > |gl|.  If bracketed, choose at+ as the minimiser
    of the cubic that interpolates fu, ft, gu, and gt.  Otherwise at+ is the step limit in the
    direction of the step (Ik_lim[1] if at > al, else Ik_lim[0]).


    Interval updating
    ~~~~~~~~~~~~~~~~~

    Given a trial value at in I, the endpoints al+ and au+ of the updated interval I+ are
    determined as follows:
        Case U1:  If f(at) > f(al), then al+ = al and au+ = at.
        Case U2:  If f(at) <= f(al) and f'(at)(al - at) > 0, then al+ = at and au+ = au.
        Case U3:  If f(at) <= f(al) and f'(at)(al - at) < 0, then al+ = at and au+ = al.
    The boundary case f'(at)(al - at) = 0 is treated as U3.
    """

    # Trial value selection.
    if ft > fl:
        # Case 1: a higher function value.  The minimum is bracketed.
        if print_flag:
            print("\tat selection, case 1.")

        bracketed = 1

        ac = cubic(al, at, fl, ft, gl, gt)
        aq = quadratic(al, at, fl, ft, gl)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\taq: " + repr(aq))

        if abs(ac - al) < abs(aq - al):
            if print_flag:
                print("\t\tabs(ac - al) < abs(aq - al), " + repr(abs(ac - al)) + " < " + repr(abs(aq - al)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - al) >= abs(aq - al), " + repr(abs(ac - al)) + " >= " + repr(abs(aq - al)))
                print("\t\tat_new = 1/2(aq + ac) = " + repr(0.5*(aq + ac)))
            at_new = 0.5*(aq + ac)

    elif gt * gl < 0.0:
        # Case 2: a lower function value and derivatives of opposite sign.  The minimum is bracketed.
        if print_flag:
            print("\tat selection, case 2.")

        bracketed = 1

        ac = cubic(al, at, fl, ft, gl, gt)
        a_s = secant(al, at, gl, gt)
        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tas: " + repr(a_s))

        if abs(ac - at) >= abs(a_s - at):
            if print_flag:
                print("\t\tabs(ac - at) >= abs(as - at), " + repr(abs(ac - at)) + " >= " + repr(abs(a_s - at)))
                print("\t\tat_new = ac = " + repr(ac))
            at_new = ac
        else:
            if print_flag:
                print("\t\tabs(ac - at) < abs(as - at), " + repr(abs(ac - at)) + " < " + repr(abs(a_s - at)))
                print("\t\tat_new = as = " + repr(a_s))
            at_new = a_s

    elif abs(gt) <= abs(gl):
        # Case 3: a lower function value, derivatives of the same sign, and a decreasing derivative magnitude.
        if print_flag:
            print("\tat selection, case 3.")

        ac, beta1, beta2 = cubic_ext(al, at, fl, ft, gl, gt, full_output=1)

        # Keep the cubic minimiser only if it lies beyond at in the direction of the step (from al
        # towards at).  Testing just ac > at is only right when at > al.
        if (ac - at) * (at - al) > 0.0 and beta2 != 0.0:
            if print_flag:
                print("\t\tac > at and beta2 != 0.0")
        elif at > al:
            if print_flag:
                print("\t\tat > al, " + repr(at) + " > " + repr(al))
            ac = Ik_lim[1]
        else:
            ac = Ik_lim[0]

        a_s = secant(al, at, gl, gt)

        if print_flag:
            print("\t\tac: " + repr(ac))
            print("\t\tas: " + repr(a_s))

        if bracketed:
            if print_flag:
                print("\t\tBracketed")
            if abs(ac - at) < abs(a_s - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) < abs(as - at), " + repr(abs(ac - at)) + " < " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) >= abs(as - at), " + repr(abs(ac - at)) + " >= " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = as = " + repr(a_s))
                at_new = a_s

            # Keep the step a safe distance away from the far end point au.
            if print_flag:
                print("\t\tRedefining at+")
            if at > al:
                at_new = min(at + d*(au - at), at_new)
                if print_flag:
                    print("\t\t\tat > al, " + repr(at) + " > " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
            else:
                at_new = max(at + d*(au - at), at_new)
                if print_flag:
                    print("\t\t\tat <= al, " + repr(at) + " <= " + repr(al))
                    print("\t\t\tat_new = " + repr(at_new))
        else:
            if print_flag:
                print("\t\tNot bracketed")
            if abs(ac - at) > abs(a_s - at):
                if print_flag:
                    print("\t\t\tabs(ac - at) > abs(as - at), " + repr(abs(ac - at)) + " > " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = ac = " + repr(ac))
                at_new = ac
            else:
                if print_flag:
                    print("\t\t\tabs(ac - at) <= abs(as - at), " + repr(abs(ac - at)) + " <= " + repr(abs(a_s - at)))
                    print("\t\t\tat_new = as = " + repr(a_s))
                at_new = a_s

            # Clip the extrapolated step to the allowed step limits.
            if print_flag:
                print("\t\tChecking limits.")
            if at_new < Ik_lim[0]:
                if print_flag:
                    print("\t\t\tat_new < Ik_lim[0], " + repr(at_new) + " < " + repr(Ik_lim[0]))
                    print("\t\t\tat_new = " + repr(Ik_lim[0]))
                at_new = Ik_lim[0]
            if at_new > Ik_lim[1]:
                if print_flag:
                    print("\t\t\tat_new > Ik_lim[1], " + repr(at_new) + " > " + repr(Ik_lim[1]))
                    print("\t\t\tat_new = " + repr(Ik_lim[1]))
                at_new = Ik_lim[1]

    else:
        # Case 4: a lower function value, derivatives of the same sign, and a non-decreasing derivative magnitude.
        if print_flag:
            print("\tat selection, case 4.")
        if bracketed:
            if print_flag:
                print("\t\tbracketed.")
            at_new = cubic(au, at, fu, ft, gu, gt)
            if print_flag:
                print("\t\tat_new = " + repr(at_new))
        elif at > al:
            if print_flag:
                print("\t\tnot bracketed but at > al, " + repr(at) + " > " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[1]))
            at_new = Ik_lim[1]
        else:
            if print_flag:
                print("\t\tnot bracketed but at <= al, " + repr(at) + " <= " + repr(al))
                print("\t\tat_new = " + repr(Ik_lim[0]))
            at_new = Ik_lim[0]

    # Interval updating.  The caller's Ik is left untouched.
    Ik_new = deepcopy(Ik)

    if ft > fl:
        # Case U1: at becomes the upper end point.
        if print_flag:
            print("\tIk update, case a, ft > fl.")
        Ik_new['a'][1] = at
        Ik_new['phi'][1] = a['phi']
        Ik_new['phi_prime'][1] = a['phi_prime']
    elif gt*(al - at) > 0.0:
        # Case U2: at replaces the lower end point.
        if print_flag:
            print("\tIk update, case b, gt*(al - at) > 0.0.")
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']
    else:
        # Case U3: at replaces the lower end point and the old lower end point becomes the upper one.
        Ik_new['a'][0] = at
        Ik_new['phi'][0] = a['phi']
        Ik_new['phi_prime'][0] = a['phi_prime']
        Ik_new['a'][1] = al
        Ik_new['phi'][1] = Ik['phi'][0]
        Ik_new['phi_prime'][1] = Ik['phi_prime'][0]

        if print_flag:
            print("\tIk update, case c.")
            print("\t\tat: " + repr(at))
            print("\t\ta['phi']: " + repr(a['phi']))
            print("\t\ta['phi_prime']: " + repr(a['phi_prime']))
            print("\t\tIk_new['a']: " + repr(Ik_new['a']))
            print("\t\tIk_new['phi']: " + repr(Ik_new['phi']))
            print("\t\tIk_new['phi_prime']: " + repr(Ik_new['phi_prime']))

    return at_new, Ik_new, bracketed
514