Package minfx :: Module simplex
[hide private]
[frames] | no frames]

Source Code for Module minfx.simplex

  1  ############################################################################### 
  2  #                                                                             # 
  3  # Copyright (C) 2003-2013 Edward d'Auvergne                                   # 
  4  #                                                                             # 
  5  # This file is part of the minfx optimisation library,                        # 
  6  # https://sourceforge.net/projects/minfx                                      # 
  7  #                                                                             # 
  8  # This program is free software: you can redistribute it and/or modify        # 
  9  # it under the terms of the GNU General Public License as published by        # 
 10  # the Free Software Foundation, either version 3 of the License, or           # 
 11  # (at your option) any later version.                                         # 
 12  #                                                                             # 
 13  # This program is distributed in the hope that it will be useful,             # 
 14  # but WITHOUT ANY WARRANTY; without even the implied warranty of              # 
 15  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               # 
 16  # GNU General Public License for more details.                                # 
 17  #                                                                             # 
 18  # You should have received a copy of the GNU General Public License           # 
 19  # along with this program.  If not, see <http://www.gnu.org/licenses/>.       # 
 20  #                                                                             # 
 21  ############################################################################### 
 22   
 23  # Module docstring. 
 24  """Downhill simplex optimization. 
 25   
 26  This file is part of the U{minfx optimisation library<https://sourceforge.net/projects/minfx>}. 
 27  """ 
 28   
 29  # Python module imports. 
 30  from numpy import argsort, average, float64, sum, take, zeros 
 31   
 32  # Minfx module imports. 
 33  from minfx.base_classes import Min 
 34   
 35   
def simplex(func=None, args=(), x0=None, func_tol=1e-25, maxiter=1e6, full_output=0, print_flag=0, print_prefix=""):
    """Downhill simplex minimisation.

    Optionally prints a short header, builds a Simplex object from the supplied arguments, and returns whatever that object's minimise() method returns.

    @keyword func:          The target function.  The Simplex class calls it as func(x, *args).
    @keyword args:          Tuple of extra arguments appended after the parameter vector when calling func.
    @keyword x0:            The starting parameter vector.
    @keyword func_tol:      The function tolerance used by the Simplex convergence test.
    @keyword maxiter:       Passed straight through to Simplex (presumably the iteration cap enforced by the Min base class - confirm there).
    @keyword full_output:   Passed straight through to Simplex (presumably selects the shape of the returned results - confirm in the Min base class).
    @keyword print_flag:    Verbosity.  Any true value prints the header;  a value of 2 or more prints one additional prefix-only line first.
    @keyword print_prefix:  Text prepended to every printed line.
    @return:                The value returned by Simplex.minimise().
    """

    # Header.
    if print_flag:
        # High verbosity gets one extra prefix-only line before the standard one.
        lead_lines = 2 if print_flag >= 2 else 1
        for _ in range(lead_lines):
            print(print_prefix)

        # Title and its underline.
        for text in ("Simplex minimisation", "~~~~~~~~~~~~~~~~~~~~"):
            print(print_prefix + text)

    # Build the minimiser (named so as not to hide the builtin min()) and run it.
    minimiser = Simplex(func, args, x0, func_tol, maxiter, full_output, print_flag, print_prefix)
    return minimiser.minimise()
48 49
class Simplex(Min):
    def __init__(self, func, args, x0, func_tol, maxiter, full_output, print_flag, print_prefix):
        """Class for downhill simplex minimisation specific functions.

        Unless you know what you are doing, you should call the function 'simplex' rather than using this class.

        For an n-parameter problem the starting simplex has n + 1 vertices:  x0 itself, plus one vertex per parameter in which only that parameter is changed (set to 2.5 * 1e-4 if it is exactly zero, otherwise multiplied by 1.05).  The function is evaluated at every vertex, the vertices are sorted by ascending function value, and xk/fk are set from the best vertex.

        @param func:            The target function, called as func(x, *args).
        @param args:            Tuple of extra arguments for func (it is concatenated onto the tuple (x,)).
        @param x0:              The starting parameter vector.  It must support len(), indexing and multiplication by a float.
        @param func_tol:        The function tolerance used by conv_test().
        @param maxiter:         Stored as self.maxiter;  not used within this class.
        @param full_output:     Stored as self.full_output;  not used within this class.
        @param print_flag:      Verbosity.  conv_test() prints diagnostics when this is 2 or more.
        @param print_prefix:    Text prepended to the lines printed by conv_test().
        """

        # Function arguments.
        self.func = func
        self.args = args
        self.xk = x0
        self.func_tol = func_tol
        self.maxiter = maxiter
        self.full_output = full_output
        self.print_flag = print_flag
        self.print_prefix = print_prefix

        # Function, gradient, and Hessian evaluation counters (only f_count is incremented in this class).
        self.f_count = 0
        self.g_count = 0
        self.h_count = 0

        # No warning yet.
        self.warning = None

        # Problem dimension and number of simplex vertices.
        self.n = len(self.xk)
        self.m = self.n + 1

        # Storage for the vertices (one per row) and their function values.
        self.simplex = zeros((self.m, self.n), float64)
        self.simplex_vals = zeros(self.m, float64)

        # The first vertex is the starting point.
        self.simplex[0] = self.xk * 1.0
        self.simplex_vals[0] = self.func(*(self.xk,)+self.args)
        self.f_count = self.f_count + 1

        # Every other vertex is the starting point with a single coordinate perturbed.
        for vertex in range(1, self.n + 1):
            coord = vertex - 1
            self.simplex[vertex] = self.xk
            if self.xk[coord] == 0.0:
                # A relative step is meaningless for a zero coordinate, so use a fixed value.
                self.simplex[vertex, coord] = 2.5 * 1e-4
            else:
                self.simplex[vertex, coord] = 1.05 * self.simplex[vertex, coord]
            self.simplex_vals[vertex] = self.func(*(self.simplex[vertex],)+self.args)
            self.f_count = self.f_count + 1

        # Sort the vertices so that index 0 is the best and index -1 the worst.
        self.order_simplex()

        # xk and fk are the vertex with the lowest function value (xk is a copy, not a view).
        self.xk = self.simplex[0] * 1.0
        self.fk = self.simplex_vals[0]

        # The center of the full simplex, used later to measure how far it has moved.
        self.center = average(self.simplex, axis=0)


    def new_param_func(self):
        """The new parameter function.

        Performs one simplex movement:  the worst vertex is reflected through the average of all the other vertices, and depending on the reflected value the move is extended, contracted, or the simplex is shrunk.  Afterwards the simplex is re-sorted, xk_new/fk_new are set from the best vertex, dfk_new is set to None, and self.dist is set to the summed absolute displacement of the simplex center.
        """

        # reflect_flag:  the reflected point should replace the worst vertex.
        # shrink_flag:  a contraction failed, so the simplex must be shrunk.
        self.reflect_flag = 1
        self.shrink_flag = 0

        # The average of all vertices bar the worst.
        self.pivot_point = average(self.simplex[:-1], axis=0)

        # Always try the reflection first.
        self.reflect()

        if self.reflect_val <= self.simplex_vals[0]:
            # At least as good as the best vertex, so try going further.
            self.extend()
        elif self.reflect_val >= self.simplex_vals[-2]:
            # No better than the second worst vertex, so the plain reflection is rejected.
            self.reflect_flag = 0
            if self.reflect_val < self.simplex_vals[-1]:
                self.contract()
            else:
                self.contract_orig()

        # Accept the reflected point.
        if self.reflect_flag:
            self.simplex[-1] = self.reflect_vector
            self.simplex_vals[-1] = self.reflect_val

        # Shrink when a contraction step requested it.
        if self.shrink_flag:
            self.shrink()

        self.order_simplex()

        # Update values.
        self.xk_new = self.simplex[0]
        self.fk_new = self.simplex_vals[0]
        self.dfk_new = None

        # Find the center of the simplex and calculate the distance moved.
        self.center_new = average(self.simplex, axis=0)
        self.dist = sum(abs(self.center_new - self.center), axis=0)


    def contract(self):
        """Contraction step.

        Evaluates the point on the reflected side of the pivot, half as far from it as the worst vertex.  If its value is below the reflected value it replaces the worst vertex, otherwise a shrink is requested.
        """

        self.contract_vector = 1.5 * self.pivot_point - 0.5 * self.simplex[-1]
        self.contract_val = self.func(*(self.contract_vector,)+self.args)
        self.f_count = self.f_count + 1

        if self.contract_val < self.reflect_val:
            self.simplex[-1] = self.contract_vector
            self.simplex_vals[-1] = self.contract_val
        else:
            self.shrink_flag = 1


    def contract_orig(self):
        """Contraction of the original simplex.

        Evaluates the midpoint between the pivot and the worst vertex.  If its value is below the worst value it replaces the worst vertex, otherwise a shrink is requested.
        """

        self.contract_orig_vector = 0.5 * (self.pivot_point + self.simplex[-1])
        self.contract_orig_val = self.func(*(self.contract_orig_vector,)+self.args)
        self.f_count = self.f_count + 1

        if self.contract_orig_val < self.simplex_vals[-1]:
            self.simplex[-1] = self.contract_orig_vector
            self.simplex_vals[-1] = self.contract_orig_val
        else:
            self.shrink_flag = 1


    def extend(self):
        """Extension step.

        Evaluates the point twice as far beyond the pivot as the worst vertex is from it.  If its value is below the reflected value it replaces the worst vertex and the reflection is cancelled;  otherwise nothing changes and the caller accepts the reflected point.
        """

        self.extend_vector = 3.0 * self.pivot_point - 2.0 * self.simplex[-1]
        self.extend_val = self.func(*(self.extend_vector,)+self.args)
        self.f_count = self.f_count + 1

        if self.extend_val < self.reflect_val:
            self.simplex[-1] = self.extend_vector
            self.simplex_vals[-1] = self.extend_val
            self.reflect_flag = 0


    def order_simplex(self):
        """Order the vertices of the simplex according to ascending function values.

        Both self.simplex and self.simplex_vals are rebound to newly created, reordered arrays.
        """

        ranking = argsort(self.simplex_vals)
        self.simplex = take(self.simplex, ranking, axis=0)
        self.simplex_vals = take(self.simplex_vals, ranking, axis=0)


    def reflect(self):
        """Reflection step.

        Evaluates the mirror image of the worst vertex through the pivot point, storing the point and its value as reflect_vector and reflect_val.  The simplex itself is not modified here.
        """

        self.reflect_vector = 2.0 * self.pivot_point - self.simplex[-1]
        self.reflect_val = self.func(*(self.reflect_vector,)+self.args)
        self.f_count = self.f_count + 1


    def shrink(self):
        """Shrinking step.

        Moves every vertex except the first halfway towards the first vertex and re-evaluates the function there.
        """

        for vertex in range(1, self.n + 1):
            self.simplex[vertex] = 0.5 * (self.simplex[0] + self.simplex[vertex])
            self.simplex_vals[vertex] = self.func(*(self.simplex[vertex],)+self.args)
            self.f_count = self.f_count + 1


    def conv_test(self, *args):
        """Convergence test.

        Finish minimising when the function difference between the highest and lowest simplex vertices is insignificant or if the simplex doesn't move.

        @param args:    Ignored.
        @return:        1 when either stopping condition is met (xk_new and fk_new are then set from the best vertex), otherwise None.
        """

        # The difference between the worst and best function values.
        spread = self.simplex_vals[-1] - self.simplex_vals[0]
        verbose = self.print_flag >= 2

        # Diagnostic printout.
        if verbose:
            print(self.print_prefix + "diff = " + repr(spread))
            print(self.print_prefix + "|diff| = " + repr(abs(spread)))
            print(self.print_prefix + "f_tol = " + repr(self.func_tol))
            print(self.print_prefix + "center = " + repr(self.pivot_point))
            try:
                # old_pivot does not exist until this printout has been executed once.
                print(self.print_prefix + "old center = " + repr(self.old_pivot))
                print(self.print_prefix + "center diff = " + repr(self.pivot_point - self.old_pivot))
            except AttributeError:
                pass
            self.old_pivot = 1.0 * self.pivot_point

        # Test the function tolerance.
        if abs(spread) <= self.func_tol:
            if verbose:
                print("\n" + self.print_prefix + "???Function tolerance reached.")
                print(self.print_prefix + "simplex_vals[-1]: " + repr(self.simplex_vals[-1]))
                print(self.print_prefix + "simplex_vals[0]: " + repr(self.simplex_vals[0]))
                print(self.print_prefix + "|diff|: " + repr(abs(spread)))
                print(self.print_prefix + "tol: " + repr(self.func_tol))
            self.xk_new = self.simplex[0]
            self.fk_new = self.simplex_vals[0]
            return 1

        # Test if simplex has not moved.
        if self.dist == 0.0:
            self.warning = "Simplex has not moved."
            self.xk_new = self.simplex[0]
            self.fk_new = self.simplex_vals[0]
            return 1

        # Not converged.
        return None


    def update(self):
        """Update function.

        Copies the new best point, its function value, and the new simplex center into the current-iteration attributes xk, fk, and center.
        """

        self.xk = self.xk_new * 1.0
        self.fk = self.fk_new
        self.center = self.center_new * 1.0
241