1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 """A line search algorithm implemented using the strong Wolfe conditions.
25
26 This file is part of the U{minfx optimisation library<https://sourceforge.net/projects/minfx>}.
27 """
28
29
30 from copy import deepcopy
31 from numpy import dot, sqrt
32
33
34 from minfx.line_search.interpolate import cubic_ext, quadratic_fafbga
35
36
# Short local alias: interpolation using f(a), f(b) and g(a) (the function
# values at both interval ends plus the gradient at the first end).
quadratic = quadratic_fafbga
38
39
def nocedal_wright_wolfe(func, func_prime, args, x, f, g, p, a_init=1.0, max_a=1e5, mu=0.001, eta=0.9, tol=1e-10, print_flag=0):
    """A line search algorithm implemented using the strong Wolfe conditions.

    Algorithm 3.2, page 59, from 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.

    Requires the gradient function.

    These functions require serious debugging and recoding to work properly (especially the safeguarding)!


    @param func:        The function to minimise.
    @type func:         func
    @param func_prime:  The function which returns the gradient vector.
    @type func_prime:   func
    @param args:        The tuple of arguments to supply to the functions func and dfunc.
    @type args:         args
    @param x:           The parameter vector at minimisation step k.
    @type x:            numpy array
    @param f:           The function value at the point x.
    @type f:            float
    @param g:           The function gradient vector at the point x.
    @type g:            numpy array
    @param p:           The descent direction.
    @type p:            numpy array
    @keyword a_init:    Initial step length.
    @type a_init:       float
    @keyword max_a:     The maximum value for the step length.
    @type max_a:        float
    @keyword mu:        Constant determining the slope for the sufficient decrease condition (0 < mu < eta < 1).
    @type mu:           float
    @keyword eta:       Constant used for the Wolfe curvature condition (0 < mu < eta < 1).
    @type eta:          float
    @keyword tol:       The function tolerance.
    @type tol:          float
    @keyword print_flag:    The higher the value, the greater the amount of info printed out.
    @type print_flag:   int
    @return:            The step length satisfying the strong Wolfe conditions, together with the
                        number of function and gradient calls made.
    @rtype:             tuple of float, int, int
    """

    # Initialise the iteration counter and the function/gradient call counters.
    i = 1
    f_count = 0
    g_count = 0

    # The a0 data structure: step length 'a', function value 'phi' and
    # directional derivative 'phi_prime' at a step length of zero.
    a0 = {}
    a0['a'] = 0.0
    a0['phi'] = f
    a0['phi_prime'] = dot(g, p)
    a_last = deepcopy(a0)

    # The a_max data structure, evaluated at the maximum allowed step length.
    a_max = {}
    a_max['a'] = max_a
    a_max['phi'] = func(*(x + a_max['a']*p,)+args)
    a_max['phi_prime'] = dot(func_prime(*(x + a_max['a']*p,)+args), p)
    f_count = f_count + 1
    g_count = g_count + 1

    # The data structure for the current trial step length, starting at a_init.
    a = {}
    a['a'] = a_init
    a['phi'] = func(*(x + a['a']*p,)+args)
    a['phi_prime'] = dot(func_prime(*(x + a['a']*p,)+args), p)
    f_count = f_count + 1
    g_count = g_count + 1

    # Debugging printout of the initial state.
    if print_flag:
        print("\n<Line search initial values>")
        print_data("Pre (a0)", i-1, a0)
        print_data("Pre (a_max)", i-1, a_max)

    # Iterate until the strong Wolfe conditions are met at a, or until a
    # bracketing interval is found and handed off to zoom().
    while True:
        if print_flag:
            print("<Line search iteration i = " + repr(i) + " >")
            print_data("Initial (a)", i, a)
            print_data("Initial (a_last)", i, a_last)

        # Sufficient decrease (Armijo) condition: phi(a) <= phi(0) + mu*a*phi'(0).
        # If violated, the minimum is bracketed by (a_last, a) - zoom in.
        # NOTE(review): Algorithm 3.2 also zooms when phi(a_i) >= phi(a_{i-1})
        # and i > 1; that second test is absent here - confirm against the
        # reference (the docstring warns this code needs debugging).
        if not a['phi'] <= a0['phi'] + mu * a['a'] * a0['phi_prime']:
            if print_flag:
                print("\tSufficient decrease condition is violated - zooming")
            return zoom(func, func_prime, args, f_count, g_count, x, f, g, p, mu, eta, i, a0, a_last, a, tol, print_flag=print_flag)
        if print_flag:
            print("\tSufficient decrease condition is OK")

        # Strong Wolfe curvature condition: |phi'(a)| <= -eta*phi'(0).
        # If met, a satisfies both conditions and is returned.
        if abs(a['phi_prime']) <= -eta * a0['phi_prime']:
            if print_flag:
                print("\tCurvature condition OK, returning a")
            return a['a'], f_count, g_count
        if print_flag:
            print("\tCurvature condition is violated")

        # A non-negative directional derivative means the minimum has been
        # passed - it is bracketed by (a, a_last), so zoom with the interval
        # endpoints swapped relative to the first zoom call above.
        if a['phi_prime'] >= 0.0:
            if print_flag:
                print("\tGradient at a['a'] is positive - zooming")

            return zoom(func, func_prime, args, f_count, g_count, x, f, g, p, mu, eta, i, a0, a, a_last, tol, print_flag=print_flag)
        if print_flag:
            print("\tGradient is negative")

        # Extrapolate: move the trial step length a quarter of the way towards
        # the maximum step length.
        a_new = a['a'] + 0.25 * (a_max['a'] - a['a'])

        # Shift the data structures and evaluate phi and phi' at the new step.
        a_last = deepcopy(a)
        a['a'] = a_new
        a['phi'] = func(*(x + a['a']*p,)+args)
        a['phi_prime'] = dot(func_prime(*(x + a['a']*p,)+args), p)
        f_count = f_count + 1
        g_count = g_count + 1
        i = i + 1
        if print_flag:
            print_data("Final (a)", i, a)
            print_data("Final (a_last)", i, a_last)

        # Safeguard against an infinite loop: bail out once successive function
        # values differ by less than the tolerance.
        if abs(a_last['phi'] - a['phi']) <= tol:
            if print_flag:
                print("abs(a_last['phi'] - a['phi']) <= tol")
            return a['a'], f_count, g_count
163
164
def print_data(text, k, a):
    """Temp func for debugging.

    Print the contents of a step-length data structure.

    @param text:    A label prefixed to the printout (e.g. "Initial (a)").
    @type text:     str
    @param k:       The iteration number to report.
    @type k:        int
    @param a:       The step data structure with the keys 'a', 'phi' and 'phi_prime'.
    @type a:        dict
    """

    print(text + " data printout:")
    print("    Iteration:    " + repr(k))
    print("    a:            " + repr(a['a']))
    print("    phi:          " + repr(a['phi']))
    print("    phi_prime:    " + repr(a['phi_prime']))
173
174
def zoom(func, func_prime, args, f_count, g_count, x, f, g, p, mu, eta, i, a0, a_lo, a_hi, tol, print_flag=0):
    """Find the minimum function value in the open interval (a_lo, a_hi)

    Algorithm 3.3, page 60, from 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.

    The arguments a0, a_lo and a_hi are dicts with the keys 'a' (step length),
    'phi' (function value) and 'phi_prime' (directional derivative), a0 being
    the zero-step-length point.  Returns the accepted step length together with
    the updated function and gradient call counts.
    """

    # Initialise the trial-point data structure, the zoom iteration counter,
    # and the record of the previous trial point (used for the safeguard and
    # the tolerance-based exit).
    aj = {}
    j = 0
    aj_last = deepcopy(a_lo)

    while True:
        if print_flag:
            print("\n<Zooming iterate j = " + repr(j) + " >")

        # Quadratic interpolation of phi(a_lo), phi(a_hi) and phi'(a_lo) to
        # propose a trial step length inside the bracketing interval.
        aj_new = quadratic(a_lo['a'], a_hi['a'], a_lo['phi'], a_hi['phi'], a_lo['phi_prime'])

        # Safeguard: force at least a 0.66 move from the previous trial point
        # towards a_hi.
        # NOTE(review): taking the max of the interpolant and this shift can
        # push the trial point outside (a_lo, a_hi) - the module docstring
        # itself flags the safeguarding as needing debugging; confirm against
        # Algorithm 3.3 of Nocedal & Wright.
        aj['a'] = max(aj_last['a'] + 0.66*(a_hi['a'] - aj_last['a']), aj_new)

        # Evaluate phi and phi' at the trial step length.
        aj['phi'] = func(*(x + aj['a']*p,)+args)
        aj['phi_prime'] = dot(func_prime(*(x + aj['a']*p,)+args), p)
        f_count = f_count + 1
        g_count = g_count + 1

        # Debugging printout of the current bracket and trial point.
        if print_flag:
            print_data("a_lo", i, a_lo)
            print_data("a_hi", i, a_hi)
            print_data("aj", i, aj)

        # Sufficient decrease violated at aj: shrink the bracket from above.
        if not aj['phi'] <= a0['phi'] + mu * aj['a'] * a0['phi_prime']:
            a_hi = deepcopy(aj)
        else:
            # Strong Wolfe curvature condition met - aj satisfies both
            # conditions, so return it.
            if abs(aj['phi_prime']) <= -eta * a0['phi_prime']:
                if print_flag:
                    print("aj: " + repr(aj))
                    print("<Finished zooming>")
                return aj['a'], f_count, g_count

            # Keep the minimum bracketed: if the derivative at aj points away
            # from a_hi, the old a_lo becomes the new upper end.
            if aj['phi_prime'] * (a_hi['a'] - a_lo['a']) >= 0.0:
                a_hi = deepcopy(a_lo)

            # Shrink the bracket from below.
            a_lo = deepcopy(aj)

        # Safeguard against an infinite loop: bail out once successive trial
        # function values differ by less than the tolerance.
        if abs(aj_last['phi'] - aj['phi']) <= tol:
            if print_flag:
                print("abs(aj_last['phi'] - aj['phi']) <= tol")
                print("<Finished zooming>")
            return aj['a'], f_count, g_count

        # Store the trial point for the next iteration.
        aj_last = deepcopy(aj)
        j = j + 1
234