1   
  2   
  3   
  4   
  5   
  6   
  7   
  8   
  9   
 10   
 11   
 12   
 13   
 14   
 15   
 16   
 17   
 18   
 19   
 20   
 21   
 22   
 23   
 24  """Downhill simplex optimization. 
 25   
 26  This file is part of the U{minfx optimisation library<https://sourceforge.net/projects/minfx>}. 
 27  """ 
 28   
 29   
 30  from numpy import argsort, average, float64, sum, take, zeros 
 31   
 32   
 33  from minfx.base_classes import Min 
 34   
 35   
 36 -def simplex(func=None, args=(), x0=None, func_tol=1e-25, maxiter=1e6, full_output=0, print_flag=0, print_prefix=""): 
  37      """Downhill simplex minimisation.""" 
 38   
 39      if print_flag: 
 40          if print_flag >= 2: 
 41              print(print_prefix) 
 42          print(print_prefix) 
 43          print(print_prefix + "Simplex minimisation") 
 44          print(print_prefix + "~~~~~~~~~~~~~~~~~~~~") 
 45      min = Simplex(func, args, x0, func_tol, maxiter, full_output, print_flag, print_prefix) 
 46      results = min.minimise() 
 47      return results 
  48   
 49   
 51 -    def __init__(self, func, args, x0, func_tol, maxiter, full_output, print_flag, print_prefix): 
  52          """Class for downhill simplex minimisation specific functions. 
 53   
 54          Unless you know what you are doing, you should call the function 'simplex' rather than using this class. 
 55          """ 
 56   
 57           
 58          self.func = func 
 59          self.args = args 
 60          self.xk = x0 
 61          self.func_tol = func_tol 
 62          self.maxiter = maxiter 
 63          self.full_output = full_output 
 64          self.print_flag = print_flag 
 65          self.print_prefix = print_prefix 
 66   
 67           
 68          self.f_count = 0 
 69          self.g_count = 0 
 70          self.h_count = 0 
 71   
 72           
 73          self.warning = None 
 74   
 75           
 76          self.n = len(self.xk) 
 77          self.m = self.n + 1 
 78   
 79           
 80          self.simplex = zeros((self.m, self.n), float64) 
 81          self.simplex_vals = zeros(self.m, float64) 
 82   
 83          self.simplex[0] = self.xk * 1.0 
 84          self.simplex_vals[0], self.f_count = self.func(*(self.xk,)+self.args), self.f_count + 1 
 85   
 86          for i in range(self.n): 
 87              j = i + 1 
 88              self.simplex[j] = self.xk 
 89              if self.xk[i] == 0.0: 
 90                  self.simplex[j, i] = 2.5 * 1e-4 
 91              else: 
 92                  self.simplex[j, i] = 1.05 * self.simplex[j, i] 
 93              self.simplex_vals[j], self.f_count = self.func(*(self.simplex[j],)+self.args), self.f_count + 1 
 94   
 95           
 96          self.order_simplex() 
 97   
 98           
 99          self.xk = self.simplex[0] * 1.0 
100          self.fk = self.simplex_vals[0] 
101   
102           
103          self.center = average(self.simplex, axis=0) 
 104   
105   
107          """The new parameter function. 
108   
109          Simplex movement. 
110          """ 
111   
112          self.reflect_flag = 1 
113          self.shrink_flag = 0 
114   
115          self.pivot_point = average(self.simplex[:-1], axis=0) 
116   
117          self.reflect() 
118          if self.reflect_val <= self.simplex_vals[0]: 
119              self.extend() 
120          elif self.reflect_val >= self.simplex_vals[-2]: 
121              self.reflect_flag = 0 
122              if self.reflect_val < self.simplex_vals[-1]: 
123                  self.contract() 
124              else: 
125                  self.contract_orig() 
126          if self.reflect_flag: 
127              self.simplex[-1], self.simplex_vals[-1] = self.reflect_vector, self.reflect_val 
128          if self.shrink_flag: 
129              self.shrink() 
130   
131          self.order_simplex() 
132   
133           
134          self.xk_new = self.simplex[0] 
135          self.fk_new = self.simplex_vals[0] 
136          self.dfk_new = None 
137   
138           
139          self.center_new = average(self.simplex, axis=0) 
140          self.dist = sum(abs(self.center_new - self.center), axis=0) 
 141   
142   
144          """Contraction step.""" 
145   
146          self.contract_vector = 1.5 * self.pivot_point - 0.5 * self.simplex[-1] 
147          self.contract_val, self.f_count = self.func(*(self.contract_vector,)+self.args), self.f_count + 1 
148          if self.contract_val < self.reflect_val: 
149              self.simplex[-1], self.simplex_vals[-1] = self.contract_vector, self.contract_val 
150          else: 
151              self.shrink_flag = 1 
 152   
153   
155          """Contraction of the original simplex.""" 
156   
157          self.contract_orig_vector = 0.5 * (self.pivot_point + self.simplex[-1]) 
158          self.contract_orig_val, self.f_count = self.func(*(self.contract_orig_vector,)+self.args), self.f_count + 1 
159          if self.contract_orig_val < self.simplex_vals[-1]: 
160              self.simplex[-1], self.simplex_vals[-1] = self.contract_orig_vector, self.contract_orig_val 
161          else: 
162              self.shrink_flag = 1 
 163   
164   
166          """Extension step.""" 
167   
168          self.extend_vector = 3.0 * self.pivot_point - 2.0 * self.simplex[-1] 
169          self.extend_val, self.f_count = self.func(*(self.extend_vector,)+self.args), self.f_count + 1 
170          if self.extend_val < self.reflect_val: 
171              self.simplex[-1], self.simplex_vals[-1] = self.extend_vector, self.extend_val 
172              self.reflect_flag = 0 
 173   
174   
176          """Order the vertices of the simplex according to ascending function values.""" 
177   
178          sorted = argsort(self.simplex_vals) 
179          self.simplex = take(self.simplex, sorted, axis=0) 
180          self.simplex_vals = take(self.simplex_vals, sorted, axis=0) 
 181   
182   
184          """Reflection step.""" 
185   
186          self.reflect_vector = 2.0 * self.pivot_point - self.simplex[-1] 
187          self.reflect_val, self.f_count = self.func(*(self.reflect_vector,)+self.args), self.f_count + 1 
 188   
189   
191          """Shrinking step.""" 
192   
193          for i in range(self.n): 
194              j = i + 1 
195              self.simplex[j] = 0.5 * (self.simplex[0] + self.simplex[j]) 
196              self.simplex_vals[j], self.f_count = self.func(*(self.simplex[j],)+self.args), self.f_count + 1 
 197   
198   
200          """Convergence test. 
201   
202          Finish minimising when the function difference between the highest and lowest simplex vertices is insignificant or if the simplex doesn't move. 
203          """ 
204   
205          if self.print_flag >= 2: 
206              print(self.print_prefix + "diff = " + repr(self.simplex_vals[-1] - self.simplex_vals[0])) 
207              print(self.print_prefix + "|diff| = " + repr(abs(self.simplex_vals[-1] - self.simplex_vals[0]))) 
208              print(self.print_prefix + "f_tol = " + repr(self.func_tol)) 
209              print(self.print_prefix + "center = " + repr(self.pivot_point)) 
210              try: 
211                  print(self.print_prefix + "old center = " + repr(self.old_pivot)) 
212                  print(self.print_prefix + "center diff = " + repr(self.pivot_point - self.old_pivot)) 
213              except AttributeError: 
214                  pass 
215              self.old_pivot = 1.0 * self.pivot_point 
216          if abs(self.simplex_vals[-1] - self.simplex_vals[0]) <= self.func_tol: 
217              if self.print_flag >= 2: 
218                  print("\n" + self.print_prefix + "???Function tolerance reached.") 
219                  print(self.print_prefix + "simplex_vals[-1]: " + repr(self.simplex_vals[-1])) 
220                  print(self.print_prefix + "simplex_vals[0]:  " + repr(self.simplex_vals[0])) 
221                  print(self.print_prefix + "|diff|:           " + repr(abs(self.simplex_vals[-1] - self.simplex_vals[0]))) 
222                  print(self.print_prefix + "tol:              " + repr(self.func_tol)) 
223              self.xk_new = self.simplex[0] 
224              self.fk_new = self.simplex_vals[0] 
225              return 1 
226   
227           
228          if self.dist == 0.0: 
229              self.warning = "Simplex has not moved." 
230              self.xk_new = self.simplex[0] 
231              self.fk_new = self.simplex_vals[0] 
232              return 1 
 233   
234   
236          """Update function.""" 
237   
238          self.xk = self.xk_new * 1.0 
239          self.fk = self.fk_new 
240          self.center = self.center_new * 1.0