1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""Downhill simplex optimisation.

This file is part of the minfx optimisation library at U{https://sourceforge.net/projects/minfx}.
"""
28
29
30 from numpy import argsort, average, float64, sum, take, zeros
31
32
33 from minfx.base_classes import Min
34
35
def simplex(func=None, args=(), x0=None, func_tol=1e-25, maxiter=1e6, full_output=0, print_flag=0, print_prefix=""):
    """Downhill simplex minimisation.

    Optionally print a section header, then build a Simplex minimiser and run it.

    @keyword func:          The function to minimise.  It is evaluated as func(x, *args).
    @type func:             callable
    @keyword args:          Extra positional arguments passed to func after the parameter vector.
    @type args:             tuple
    @keyword x0:            The starting parameter vector (must support len() and indexing; presumably a numpy array).
    @keyword func_tol:      Convergence tolerance on |f(worst vertex) - f(best vertex)|.
    @type func_tol:         float
    @keyword maxiter:       Passed through to the minimiser (presumably the iteration limit enforced by the base class).
    @keyword full_output:   Passed through to the minimiser (not read in this module; presumably controls the result format).
    @keyword print_flag:    Verbosity.  0 prints nothing, >= 1 prints a header, >= 2 adds an extra blank line here and per-iteration diagnostics.
    @type print_flag:       int
    @keyword print_prefix:  String prepended to every printed line.
    @type print_prefix:     str
    @return:                Whatever Simplex.minimise() returns (inherited; not shown in this module).
    """

    if print_flag:
        if print_flag >= 2:
            print(print_prefix)
        print(print_prefix)
        print(print_prefix + "Simplex minimisation")
        print(print_prefix + "~~~~~~~~~~~~~~~~~~~~")

    # Named 'minimiser' rather than 'min' so the built-in min() is not shadowed.
    minimiser = Simplex(func, args, x0, func_tol, maxiter, full_output, print_flag, print_prefix)
    return minimiser.minimise()
48
49
class Simplex(Min):
    """Downhill simplex (Nelder-Mead) minimiser.

    This class supplies the simplex-specific steps.  minimise() is not defined here and is
    presumably inherited from Min.  Unless you know what you are doing, call the module-level
    function 'simplex' rather than using this class directly.
    """

    # NOTE(review): the class header and the method def lines were absent from the source and have
    # been reconstructed.  contract, contract_orig, extend, reflect, shrink and order_simplex are named
    # after their call sites; new_param_func, conv_test and update are assumed to be the hooks called
    # by Min.minimise() -- confirm against minfx.base_classes.

    def __init__(self, func, args, x0, func_tol, maxiter, full_output, print_flag, print_prefix):
        """Class for downhill simplex minimisation specific functions.

        Unless you know what you are doing, you should call the function 'simplex' rather than using this class.

        Builds the initial simplex around x0, evaluates func at every vertex and sorts the vertices.

        @param func:            The function to minimise, evaluated as func(x, *args).  Its values are
                                stored in a float64 array, so it must return a real scalar.
        @param args:            Extra positional arguments for func.
        @param x0:              The starting parameter vector (must support len() and indexing).
        @param func_tol:        Convergence tolerance on |f(worst vertex) - f(best vertex)|.
        @param maxiter:         Stored only (presumably read by the base class as the iteration limit).
        @param full_output:     Stored only (presumably read by the base class).
        @param print_flag:      Verbosity; >= 2 enables the diagnostics printed by conv_test().
        @param print_prefix:    String prepended to every printed line.
        """

        # Store the arguments.
        self.func = func
        self.args = args
        self.xk = x0
        self.func_tol = func_tol
        self.maxiter = maxiter
        self.full_output = full_output
        self.print_flag = print_flag
        self.print_prefix = print_prefix

        # Evaluation counters.  Only f_count is incremented in this class; g_count and h_count are
        # presumably the gradient/Hessian counters expected by the base class.
        self.f_count = 0
        self.g_count = 0
        self.h_count = 0

        # Termination warning message (None until a warning is raised).
        self.warning = None

        # Problem dimension n; the simplex has m = n + 1 vertices.
        self.n = len(self.xk)
        self.m = self.n + 1

        # Vertex coordinates (one row per vertex) and the function value at each vertex.
        self.simplex = zeros((self.m, self.n), float64)
        self.simplex_vals = zeros(self.m, float64)

        # Vertex 0 is the starting point itself.
        self.simplex[0] = self.xk * 1.0
        self.simplex_vals[0] = self._call_func(self.xk)

        # Vertex i+1 differs from x0 in coordinate i only: that coordinate is scaled by 1.05, or set
        # to 2.5e-4 when it is exactly zero (scaling zero would not move it).
        for i in range(self.n):
            j = i + 1
            self.simplex[j] = self.xk
            if self.xk[i] == 0.0:
                self.simplex[j, i] = 2.5 * 1e-4
            else:
                self.simplex[j, i] = 1.05 * self.simplex[j, i]
            self.simplex_vals[j] = self._call_func(self.simplex[j])

        # Sort so that index 0 is the best (lowest) vertex and index -1 the worst.
        self.order_simplex()

        # The current position and function value are those of the best vertex (* 1.0 takes a copy).
        self.xk = self.simplex[0] * 1.0
        self.fk = self.simplex_vals[0]

        # Centroid of the full simplex, compared against after each iteration to detect stagnation.
        self.center = average(self.simplex, axis=0)

    def _call_func(self, x):
        """Evaluate func at x with the extra arguments and increment the function-call counter.

        The counter is only incremented after func returns successfully.
        """

        value = self.func(x, *self.args)
        self.f_count += 1
        return value

    def new_param_func(self):
        """The new parameter function.

        Simplex movement.  The worst vertex is reflected through the centroid of the other vertices.
        Depending on how good the reflected point is, an extension or a contraction is then tried.
        If the contraction fails, the whole simplex is shrunk towards the best vertex.
        """

        # reflect_flag: accept the reflected point.  shrink_flag: a contraction failed, so shrink.
        self.reflect_flag = 1
        self.shrink_flag = 0

        # Centroid of every vertex except the worst one.
        self.pivot_point = average(self.simplex[:-1], axis=0)

        self.reflect()
        if self.reflect_val <= self.simplex_vals[0]:
            # The reflected point is at least as good as the best vertex: try moving further.
            self.extend()
        elif self.reflect_val >= self.simplex_vals[-2]:
            # The reflected point is no better than the second-worst vertex: contract instead.
            self.reflect_flag = 0
            if self.reflect_val < self.simplex_vals[-1]:
                self.contract()
            else:
                self.contract_orig()

        # If no extension was accepted and no contraction was tried, the reflected point replaces
        # the worst vertex.
        if self.reflect_flag:
            self.simplex[-1], self.simplex_vals[-1] = self.reflect_vector, self.reflect_val
        if self.shrink_flag:
            self.shrink()

        # Re-sort the vertices by function value.
        self.order_simplex()

        # The candidate new position is the best vertex.  No gradient is computed by this method.
        self.xk_new = self.simplex[0]
        self.fk_new = self.simplex_vals[0]
        self.dfk_new = None

        # Sum of absolute centroid displacements, used by conv_test() to detect a stationary simplex.
        self.center_new = average(self.simplex, axis=0)
        self.dist = sum(abs(self.center_new - self.center), axis=0)

    def contract(self):
        """Contraction step.

        Try the point pivot + 0.5 * (pivot - worst), i.e. halfway between the centroid and the
        reflected point.  Accept it if it beats the reflected point, otherwise request a shrink.
        """

        self.contract_vector = 1.5 * self.pivot_point - 0.5 * self.simplex[-1]
        self.contract_val = self._call_func(self.contract_vector)
        if self.contract_val < self.reflect_val:
            self.simplex[-1], self.simplex_vals[-1] = self.contract_vector, self.contract_val
        else:
            self.shrink_flag = 1

    def contract_orig(self):
        """Contraction of the original simplex.

        Try the midpoint between the centroid and the worst vertex.  Accept it if it beats the worst
        vertex, otherwise request a shrink.
        """

        self.contract_orig_vector = 0.5 * (self.pivot_point + self.simplex[-1])
        self.contract_orig_val = self._call_func(self.contract_orig_vector)
        if self.contract_orig_val < self.simplex_vals[-1]:
            self.simplex[-1], self.simplex_vals[-1] = self.contract_orig_vector, self.contract_orig_val
        else:
            self.shrink_flag = 1

    def extend(self):
        """Extension step.

        Try the point pivot + 2 * (pivot - worst), twice as far from the centroid as the reflected
        point.  If it beats the reflected point, it replaces the worst vertex and the reflection is
        discarded.
        """

        self.extend_vector = 3.0 * self.pivot_point - 2.0 * self.simplex[-1]
        self.extend_val = self._call_func(self.extend_vector)
        if self.extend_val < self.reflect_val:
            self.simplex[-1], self.simplex_vals[-1] = self.extend_vector, self.extend_val
            self.reflect_flag = 0

    def order_simplex(self):
        """Order the vertices of the simplex according to ascending function values."""

        # 'order' rather than 'sorted' so the built-in sorted() is not shadowed.
        order = argsort(self.simplex_vals)
        self.simplex = take(self.simplex, order, axis=0)
        self.simplex_vals = take(self.simplex_vals, order, axis=0)

    def reflect(self):
        """Reflection step: mirror the worst vertex through the centroid, pivot + (pivot - worst)."""

        self.reflect_vector = 2.0 * self.pivot_point - self.simplex[-1]
        self.reflect_val = self._call_func(self.reflect_vector)

    def shrink(self):
        """Shrinking step: move every vertex except the best one halfway towards the best vertex."""

        for j in range(1, self.m):
            self.simplex[j] = 0.5 * (self.simplex[0] + self.simplex[j])
            self.simplex_vals[j] = self._call_func(self.simplex[j])

    def conv_test(self, *args):
        """Convergence test.

        Finish minimising when the function difference between the highest and lowest simplex vertices is insignificant or if the simplex doesn't move.

        @param args:    Ignored.  They are accepted so the caller may pass its own values (presumably
                        fk_new, fk and dfk_new -- confirm against the base class).
        @return:        1 when converged or when the simplex has stopped moving, otherwise None.
        """

        # Spread of function values between the worst and the best vertex.
        diff = self.simplex_vals[-1] - self.simplex_vals[0]

        if self.print_flag >= 2:
            print(self.print_prefix + "diff = " + repr(diff))
            print(self.print_prefix + "|diff| = " + repr(abs(diff)))
            print(self.print_prefix + "f_tol = " + repr(self.func_tol))
            print(self.print_prefix + "center = " + repr(self.pivot_point))
            try:
                print(self.print_prefix + "old center = " + repr(self.old_pivot))
                print(self.print_prefix + "center diff = " + repr(self.pivot_point - self.old_pivot))
            except AttributeError:
                # There is no previous pivot on the first diagnostic print.
                pass
            self.old_pivot = 1.0 * self.pivot_point

        if abs(diff) <= self.func_tol:
            if self.print_flag >= 2:
                # NOTE(review): the '???' prefix may be encoding damage in the source -- confirm.
                print("\n" + self.print_prefix + "???Function tolerance reached.")
                print(self.print_prefix + "simplex_vals[-1]: " + repr(self.simplex_vals[-1]))
                print(self.print_prefix + "simplex_vals[0]: " + repr(self.simplex_vals[0]))
                print(self.print_prefix + "|diff|: " + repr(abs(diff)))
                print(self.print_prefix + "tol: " + repr(self.func_tol))
            self.xk_new = self.simplex[0]
            self.fk_new = self.simplex_vals[0]
            return 1

        # The simplex centroid did not move at all during the last iteration.
        if self.dist == 0.0:
            self.warning = "Simplex has not moved."
            self.xk_new = self.simplex[0]
            self.fk_new = self.simplex_vals[0]
            return 1

    def update(self):
        """Update function: adopt the results of the last iteration (* 1.0 takes copies)."""

        self.xk = self.xk_new * 1.0
        self.fk = self.fk_new
        self.center = self.center_new * 1.0
241