1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 from LinearAlgebra import inverse
25 from Numeric import Float64, dot, matrixmultiply, outerproduct, sqrt, zeros
26
27 from newton import Newton
28 from base_classes import Min, Trust_region
29
30
def steihaug(func=None, dfunc=None, d2func=None, args=(), x0=None, func_tol=1e-25, grad_tol=None, maxiter=1e6, epsilon=1e-8, delta_max=1e5, delta0=1.0, eta=0.2, full_output=0, print_flag=0, print_prefix=""):
    """Steihaug conjugate-gradient trust region algorithm.

    Page 75 from 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.

    The CG-Steihaug algorithm is:

        epsilon > 0
        p0 = 0, r0 = g, d0 = -r0
        if ||r0|| < epsilon:
            return p = p0
        while 1:
            if djT.B.dj <= 0:
                Find tau such that p = pj + tau.dj minimises m(p) in (4.9) and satisfies ||p|| = delta
                return p
            aj = rjT.rj / djT.B.dj
            pj+1 = pj + aj.dj
            if ||pj+1|| >= delta:
                Find tau such that p = pj + tau.dj satisfies ||p|| = delta
                return p
            rj+1 = rj + aj.B.dj
            if ||rj+1|| < epsilon.||r0||:
                return p = pj+1
            bj+1 = rj+1T.rj+1 / rjT.rj
            dj+1 = -rj+1 + bj+1.dj

    This function is a thin convenience wrapper:  it optionally prints a banner, builds a
    'Steihaug' object from the arguments (all of which are forwarded unchanged and in the same
    order) and returns whatever that object's minimise() method returns.

    Arguments read directly by this function:

        print_flag:    Verbosity.  Any true value prints the banner; a value >= 2 prints one
                       additional prefix-only line before it.
        print_prefix:  String prepended to every line printed here.

    NOTE(review):  the meaning of the remaining arguments (tolerances, maxiter, trust region
    parameters delta_max/delta0/eta, full_output) is defined by the Steihaug class and its base
    classes, and the return value is defined by minimise(), none of which is implemented here --
    confirm against those classes.
    """

    # Banner.  The extra prefix-only line is only emitted at the higher verbosity level.
    if print_flag:
        if print_flag >= 2:
            print(print_prefix)
        print(print_prefix)
        print(print_prefix + "CG-Steihaug minimisation")
        print(print_prefix + "~~~~~~~~~~~~~~~~~~~~~~~~")

    # Build the minimiser (named so as not to shadow the builtin min()) and run it.
    minimiser = Steihaug(func, dfunc, d2func, args, x0, func_tol, grad_tol, maxiter, epsilon, delta_max, delta0, eta, full_output, print_flag, print_prefix)
    return minimiser.minimise()
67
68
class Steihaug(Min, Trust_region, Newton):
    """CG-Steihaug trust region minimiser.

    The trust region subproblem is solved approximately by get_pk() (conjugate gradients truncated
    at the trust region boundary or on non-positive curvature).  The outer iteration loop,
    convergence tests and the trust region radius update are not implemented in this class.

    NOTE(review):  those are presumably supplied by the Min, Trust_region and Newton base classes
    -- confirm.
    """

    def __init__(self, func, dfunc, d2func, args, x0, func_tol, grad_tol, maxiter, epsilon, delta_max, delta0, eta, full_output, print_flag, print_prefix):
        """Class for Steihaug conjugate-gradient trust region minimisation specific functions.

        Unless you know what you are doing, you should call the function 'steihaug' rather than
        using this class.

        All arguments are stored on the instance.  x0 is stored as self.xk (the current point) and
        delta0 as self.delta (the current trust region radius).
        """

        # Function arguments.
        self.func = func
        self.dfunc = dfunc
        self.d2func = d2func
        self.args = args
        self.xk = x0
        self.func_tol = func_tol
        self.grad_tol = grad_tol
        self.maxiter = maxiter
        self.full_output = full_output
        self.print_flag = print_flag
        self.print_prefix = print_prefix
        self.epsilon = epsilon
        self.delta_max = delta_max
        self.delta = delta0
        self.eta = eta

        # Counters for function, gradient and Hessian calls.
        # NOTE(review):  h_count is never incremented in this class; presumably the Newton code
        # does so -- confirm.
        self.f_count = 0
        self.g_count = 0
        self.h_count = 0

        # No warning to start with.
        self.warning = None

        # Inherited set up routines (defined outside this class).
        self.setup_conv_tests()
        self.setup_newton()

        # The update run by update() once the trust region step has been accepted.
        self.specific_update = self.update_newton


    def get_pk(self):
        """The CG-Steihaug algorithm.

        Approximately solve the trust region subproblem and return the step p.  Reads self.xk,
        self.dfk (used as the gradient g), self.d2fk (used as the matrix B), self.epsilon and
        self.delta.  The CG state is kept on the instance (pj, rj, dj, B and the *_new values)
        because get_tau() reads self.pj, self.dj and self.delta.

        NOTE(review):  self.dfk and self.d2fk are not set in this class; presumably setup_newton()
        and update_newton() maintain them -- confirm.
        """

        # Initial values at j = 0.  Multiplying by 1.0 yields new arrays rather than references to
        # self.dfk and self.d2fk.
        self.pj = zeros(len(self.xk), Float64)
        self.rj = self.dfk * 1.0
        self.dj = -self.dfk * 1.0
        self.B = self.d2fk * 1.0
        len_r0 = sqrt(dot(self.rj, self.rj))
        length_test = self.epsilon * len_r0

        debug = self.print_flag >= 2
        if debug:
            print(self.print_prefix + "p0: " + repr(self.pj))
            print(self.print_prefix + "r0: " + repr(self.rj))
            print(self.print_prefix + "d0: " + repr(self.dj))

        # The gradient is already negligible, so take the zero step.
        if len_r0 < self.epsilon:
            if debug:
                print(self.print_prefix + "len rj < epsilon.")
            return self.pj

        # Iterate over j.
        j = 0
        while True:
            # The curvature djT.B.dj (B.dj is reused for the residual update below).
            B_dj = dot(self.B, self.dj)
            curv = dot(self.dj, B_dj)
            if debug:
                print(self.print_prefix + "\nIteration j = " + repr(j))
                print(self.print_prefix + "Curv: " + repr(curv))

            # First test:  non-positive curvature, so follow dj to the trust region boundary.
            if curv <= 0.0:
                tau = self.get_tau()
                if debug:
                    print(self.print_prefix + "curv <= 0.0, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            # The CG step.
            rj_sqrd = dot(self.rj, self.rj)
            aj = rj_sqrd / curv
            self.pj_new = self.pj + aj * self.dj
            if debug:
                print(self.print_prefix + "aj: " + repr(aj))
                print(self.print_prefix + "pj+1: " + repr(self.pj_new))

            # Second test:  the full CG step leaves the trust region, so truncate at the boundary.
            if sqrt(dot(self.pj_new, self.pj_new)) >= self.delta:
                tau = self.get_tau()
                if debug:
                    print(self.print_prefix + "sqrt(dot(self.pj_new, self.pj_new)) >= self.delta, therefore tau = " + repr(tau))
                return self.pj + tau * self.dj

            # The new residual.
            self.rj_new = self.rj + aj * B_dj
            if debug:
                print(self.print_prefix + "rj+1: " + repr(self.rj_new))

            # Third test:  the residual has dropped below epsilon.||r0||, so accept the CG point.
            rj_new_sqrd = dot(self.rj_new, self.rj_new)
            len_rj_new = sqrt(rj_new_sqrd)
            if len_rj_new < length_test:
                if debug:
                    print(self.print_prefix + "sqrt(dot(self.rj_new, self.rj_new)) < length_test")
                return self.pj_new

            # The new search direction.
            bj_new = rj_new_sqrd / rj_sqrd
            self.dj_new = -self.rj_new + bj_new * self.dj
            if debug:
                print(self.print_prefix + "len rj+1: " + repr(len_rj_new))
                print(self.print_prefix + "epsilon.||r0||: " + repr(length_test))
                print(self.print_prefix + "bj+1: " + repr(bj_new))
                print(self.print_prefix + "dj+1: " + repr(self.dj_new))

            # Update j+1 to j (as new arrays).
            self.pj = self.pj_new * 1.0
            self.rj = self.rj_new * 1.0
            self.dj = self.dj_new * 1.0
            j = j + 1


    def get_tau(self):
        """Function to find tau such that p = pj + tau.dj, and ||p|| = delta.

        Expanding ||pj + tau.dj||^2 = delta^2 gives the quadratic

            tau^2.||dj||^2 + 2.tau.(pj.dj) + ||pj||^2 - delta^2 = 0,

        and the root with the positive square root is returned:

            tau = (-pj.dj + sqrt((pj.dj)^2 - ||dj||^2.(||pj||^2 - delta^2))) / ||dj||^2.

        The whole numerator is divided by ||dj||^2; dividing only the square root gives a point
        that is not on the boundary whenever pj.dj is non-zero.

        Reads self.pj, self.dj and self.delta.
        """

        dot_pj_dj = dot(self.pj, self.dj)
        len_dj_sqrd = dot(self.dj, self.dj)
        discriminant = dot_pj_dj**2 - len_dj_sqrd * (dot(self.pj, self.pj) - self.delta**2)

        tau = (-dot_pj_dj + sqrt(discriminant)) / len_dj_sqrd
        return tau


    # NOTE(review):  the 'def' line of this method was missing from the extracted source.  The name
    # 'new_param_func' is reconstructed -- confirm it matches the hook called by the base class.
    def new_param_func(self):
        """Find the CG-Steihaug minimiser.

        Compute the step self.pk with get_pk(), form the trial point self.xk_new = xk + pk, and
        evaluate func and dfunc there (stored as self.fk_new and self.dfk_new), incrementing
        self.f_count and self.g_count by one each.
        """

        # The trust region step.
        self.pk = self.get_pk()

        # The trial point, and the function and gradient values there.
        self.xk_new = self.xk + self.pk
        call_args = (self.xk_new,) + self.args
        self.fk_new = self.func(*call_args)
        self.f_count = self.f_count + 1
        self.dfk_new = self.dfunc(*call_args)
        self.g_count = self.g_count + 1


    # NOTE(review):  the 'def' line of this method was missing from the extracted source.  The name
    # 'update' is reconstructed -- confirm it matches the hook called by the base class.
    def update(self):
        """Update function.

        Run the trust region update.  If this update decides to shift xk+1 to xk, then run the
        Newton update.
        """

        self.trust_region_update()
        if self.shift_flag:
            self.specific_update()
222