1 import sys
2
3 try:
4 from Numeric import Float64, add, argsort, average, copy, take, zeros
5 except ImportError:
6 raise NameError, "Please upgrade Numeric."
7
8 from generic_minimise import generic_minimise
9
10
def __init__(self, func, args=(), x0=None, func_tol=1e-5, maxiter=None, full_output=0, print_flag=0):
    """Set up a downhill simplex minimisation.

    The caller's settings are stored on the instance, the call counters
    are zeroed, the initial simplex is built around x0 and sorted, and
    the minimiser entry point is exposed as self.minimise.

    Arguments:
        func:         the function to minimise.  The simplex methods
                      evaluate it with a parameter vector followed by
                      the elements of args.
        args:         extra arguments handed to func after the vector.
        x0:           the starting parameter vector.  It must be given
                      even though it defaults to None, because its
                      length defines the dimensionality of the problem.
        func_tol:     stored as self.func_tol.
        maxiter:      stored as self.maxiter.
        full_output:  stored as self.full_output.
        print_flag:   stored as self.print_flag.
    """

    # Caller supplied settings.
    self.func, self.args = func, args
    self.xk = x0
    self.func_tol, self.maxiter = func_tol, maxiter
    self.full_output, self.print_flag = full_output, print_flag

    # Function, gradient and Hessian call counters.
    self.f_count = self.g_count = self.h_count = 0

    # No warning has been raised yet.
    self.warning = None

    # An n dimensional problem uses a simplex of m = n + 1 vertices.
    self.n = len(self.xk)
    self.m = self.n + 1

    # Build the starting simplex and put the best vertex first.
    self.create_simplex()
    self.order_simplex()

    # The current point and function value are those of the best vertex.
    self.xk, self.fk = self.simplex[0], self.simplex_vals[0]

    # Expose the generic minimisation loop under the name 'minimise'.
    self.minimise = self.generic_minimise
44
45
47 "Simplex movement."
48
49 self.reflect_flag = 1
50 self.shrink_flag = 0
51
52 self.pivot_point = average(self.simplex[:-1])
53
54 self.reflect()
55 if self.reflect_val <= self.simplex_vals[0]:
56 self.extend()
57 elif self.reflect_val >= self.simplex_vals[-2]:
58 self.reflect_flag = 0
59 if self.reflect_val < self.simplex_vals[-1]:
60 self.contract()
61 else:
62 self.contract_orig()
63 if self.reflect_flag:
64 self.simplex[-1], self.simplex_vals[-1] = self.reflect_vector, self.reflect_val
65 if self.shrink_flag:
66 self.shrink()
67
68 self.order_simplex()
69
70 self.xk = self.simplex[0]
71 self.fk = self.simplex_vals[0]
72
73
# NOTE(review): the 'def' line is missing from this listing; the name and
# signature are taken from the self.contract() call site.
def contract(self):
    """Contraction step.

    Evaluates the point 1.5 * pivot_point - 0.5 * worst_vertex.  If its
    function value is lower than that of the reflected point
    (self.reflect_val), it replaces the worst vertex (the last row of
    the simplex).  Otherwise self.shrink_flag is set to 1 to request a
    shrink of the whole simplex.

    self.f_count is incremented once for the function evaluation.
    """

    worst = self.simplex[-1]
    self.contract_vector = 1.5 * self.pivot_point - 0.5 * worst

    # Direct call instead of apply(), which no longer exists in Python 3.
    self.contract_val = self.func(self.contract_vector, *self.args)
    self.f_count = self.f_count + 1

    if self.contract_val < self.reflect_val:
        self.simplex[-1] = self.contract_vector
        self.simplex_vals[-1] = self.contract_val
    else:
        self.shrink_flag = 1
83
84
# NOTE(review): the 'def' line is missing from this listing; the name and
# signature are taken from the self.contract_orig() call site.
def contract_orig(self):
    """Contraction of the original simplex.

    Evaluates the midpoint between the pivot point and the worst vertex
    (the last row of the simplex).  If its function value is lower than
    that of the worst vertex, it replaces the worst vertex.  Otherwise
    self.shrink_flag is set to 1 to request a shrink of the whole
    simplex.

    self.f_count is incremented once for the function evaluation.
    """

    worst = self.simplex[-1]
    self.contract_orig_vector = 0.5 * (self.pivot_point + worst)

    # Direct call instead of apply(), which no longer exists in Python 3.
    self.contract_orig_val = self.func(self.contract_orig_vector, *self.args)
    self.f_count = self.f_count + 1

    if self.contract_orig_val < self.simplex_vals[-1]:
        self.simplex[-1] = self.contract_orig_vector
        self.simplex_vals[-1] = self.contract_orig_val
    else:
        self.shrink_flag = 1
94
95
# NOTE(review): the 'def' line is missing from this listing; the name and
# signature are taken from the self.create_simplex() call site.
def create_simplex(self):
    """Create the initial simplex and calculate the vertex function values.

    self.xk becomes the first vertex.  Each of the remaining self.n
    vertices is a copy of self.xk with a single coordinate perturbed:
    a non-zero coordinate is scaled by 1.05, and a coordinate equal to
    zero is set to 2.5e-4.

    The vertices are stored as the rows of self.simplex (self.m by
    self.n) and their function values in self.simplex_vals.
    self.f_count is incremented once per vertex evaluated.
    """

    self.simplex = zeros((self.m, self.n), Float64)
    self.simplex_vals = zeros(self.m, Float64)

    # First vertex: the starting point itself.  The function is called
    # directly rather than through apply(), which no longer exists in
    # Python 3.
    self.simplex[0] = self.xk
    self.simplex_vals[0] = self.func(self.xk, *self.args)
    self.f_count = self.f_count + 1

    # Remaining vertices: perturb one coordinate of the starting point each.
    for i in range(self.n):
        j = i + 1
        self.simplex[j] = self.xk
        if self.xk[i] == 0.0:
            # Scaling a zero coordinate would leave it unchanged, so a
            # small absolute offset is used instead.
            self.simplex[j, i] = 2.5 * 1e-4
        else:
            self.simplex[j, i] = 1.05 * self.simplex[j, i]
        self.simplex_vals[j] = self.func(self.simplex[j], *self.args)
        self.f_count = self.f_count + 1
116
117
118
# NOTE(review): the 'def' line is missing from this listing; the name and
# signature are taken from the self.extend() call site.
def extend(self):
    """Extension step.

    Evaluates the point 3.0 * pivot_point - 2.0 * worst_vertex.  If its
    function value is lower than that of the reflected point
    (self.reflect_val), it replaces the worst vertex (the last row of
    the simplex) and self.reflect_flag is cleared.  Otherwise nothing
    is changed.

    self.f_count is incremented once for the function evaluation.
    """

    worst = self.simplex[-1]
    self.extend_vector = 3.0 * self.pivot_point - 2.0 * worst

    # Direct call instead of apply(), which no longer exists in Python 3.
    self.extend_val = self.func(self.extend_vector, *self.args)
    self.f_count = self.f_count + 1

    if self.extend_val < self.reflect_val:
        self.simplex[-1] = self.extend_vector
        self.simplex_vals[-1] = self.extend_val

        # The extended point has been accepted, so the reflected point
        # must not be stored over it afterwards.
        self.reflect_flag = 0
127
128
# NOTE(review): the 'def' line is missing from this listing; the name and
# signature are taken from the self.order_simplex() call site.
def order_simplex(self):
    """Order the vertices of the simplex by ascending function value.

    After the call the first row of self.simplex is the vertex with
    the lowest function value and the last row is the one with the
    highest.  self.simplex_vals is reordered to match.
    """

    # Named 'order' so that the builtin sorted() is not shadowed.
    order = argsort(self.simplex_vals)

    # The axis is given explicitly so that whole vertices (rows) are
    # selected regardless of the array library's default axis.
    self.simplex = take(self.simplex, order, 0)
    self.simplex_vals = take(self.simplex_vals, order, 0)
134
135
# NOTE(review): the 'def' line is missing from this listing; the name and
# signature are taken from the self.reflect() call site.
def reflect(self):
    """Reflection step.

    Computes the point 2.0 * pivot_point - worst_vertex, where the
    worst vertex is the last row of the simplex, and evaluates the
    function there.  The point and its value are stored as
    self.reflect_vector and self.reflect_val; the simplex itself is
    not modified.

    self.f_count is incremented once for the function evaluation.
    """

    worst = self.simplex[-1]
    self.reflect_vector = 2.0 * self.pivot_point - worst

    # Direct call instead of apply(), which no longer exists in Python 3.
    self.reflect_val = self.func(self.reflect_vector, *self.args)
    self.f_count = self.f_count + 1
141
142
# NOTE(review): the 'def' line is missing from this listing; the name and
# signature are taken from the self.shrink() call site.
def shrink(self):
    """Shrinking step.

    Every vertex except the first is moved halfway towards the first
    vertex (row 0 of the simplex) and its function value is
    recalculated.  The first vertex and its value are left untouched.

    self.f_count is incremented once per vertex re-evaluated.
    """

    anchor = self.simplex[0]
    for j in range(1, self.n + 1):
        self.simplex[j] = 0.5 * (anchor + self.simplex[j])

        # Direct call instead of apply(), which no longer exists in Python 3.
        self.simplex_vals[j] = self.func(self.simplex[j], *self.args)
        self.f_count = self.f_count + 1
150
151
153
154 if abs(self.simplex_vals[-1] - self.simplex_vals[0]) < self.func_tol:
155 return 1
156