Package minfx :: Module base_classes
[hide private]
[frames] | [no frames]

Source Code for Module minfx.base_classes

  1  ############################################################################### 
  2  #                                                                             # 
  3  # Copyright (C) 2003-2014 Edward d'Auvergne                                   # 
  4  #                                                                             # 
  5  # This file is part of the minfx optimisation library,                        # 
  6  # https://gna.org/projects/minfx                                              # 
  7  #                                                                             # 
  8  # This program is free software: you can redistribute it and/or modify        # 
  9  # it under the terms of the GNU General Public License as published by        # 
 10  # the Free Software Foundation, either version 3 of the License, or           # 
 11  # (at your option) any later version.                                         # 
 12  #                                                                             # 
 13  # This program is distributed in the hope that it will be useful,             # 
 14  # but WITHOUT ANY WARRANTY; without even the implied warranty of              # 
 15  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               # 
 16  # GNU General Public License for more details.                                # 
 17  #                                                                             # 
 18  # You should have received a copy of the GNU General Public License           # 
 19  # along with this program.  If not, see <http://www.gnu.org/licenses/>.       # 
 20  #                                                                             # 
 21  ############################################################################### 
 22   
 23  # Module docstring. 
  24  """Base classes for the U{minfx optimisation library<https://gna.org/projects/minfx>}. 
 25   
 26  This module contains the following base classes: 
 27      - Min:                The base class containing the main iterative minimisation loop and a few other base class functions. 
 28      - Line_search:        The base class containing the generic line search functions. 
 29      - Trust_region:       The base class containing the generic trust-region functions. 
 30      - Conjugate_gradient: The base class containing the generic conjugate gradient functions. 
 31  """ 
 32   
 33  # Inbuilt python modules. 
 34  ######################### 
 35   
 36  from numpy import dot, inf, sqrt 
 37  from numpy.linalg import inv, LinAlgError 
 38  from re import match 
 39  import sys 
 40   
 41   
 42  # Line search functions. 
 43  ######################## 
 44   
 45  from minfx.line_search.backtrack import backtrack 
 46  from minfx.line_search.nocedal_wright_interpol import nocedal_wright_interpol 
 47  from minfx.line_search.nocedal_wright_wolfe import nocedal_wright_wolfe 
 48  from minfx.line_search.more_thuente import more_thuente 
 49   
 50   
 51  # Hessian modifications. 
 52  ######################## 
 53   
 54  from minfx.hessian_mods.cholesky_mod import cholesky_mod 
 55  from minfx.hessian_mods.eigenvalue import eigenvalue 
 56  from minfx.hessian_mods.gmw81 import gmw 
 57  from minfx.hessian_mods.gmw81_old import gmw_old 
 58  from minfx.hessian_mods.se99 import se99 
 59   
 60   
 90   
 91   
 92  # The generic minimisation base class (containing the main iterative loop). 
 93  ########################################################################### 
 94   
class Min:
    def __init__(self):
        """Base class containing the main minimisation iterative loop algorithm.

        The algorithm is defined in the minimise function.  Also supplied are generic setup, convergence tests, and update functions.
        """


    def double_test(self, fk_new, fk, gk):
        """Default base class function for both function and gradient convergence tests.

        Test if the minimum function tolerance between fk and fk+1 has been reached as well as if the minimum gradient tolerance has been reached.

        @param fk_new:  The function value at iteration k+1.
        @param fk:      The function value at iteration k.
        @param gk:      The gradient vector at iteration k+1.
        @return:        1 if either the function or gradient tolerance has been reached, None otherwise.
        """

        # Test the function tolerance.
        if abs(fk_new - fk) <= self.func_tol:
            if self.print_flag >= 2:
                print("\n" + self.print_prefix + "Function tolerance reached.")
                print(self.print_prefix + "fk: " + repr(fk))
                print(self.print_prefix + "fk+1: " + repr(fk_new))
                print(self.print_prefix + "|fk+1 - fk|: " + repr(abs(fk_new - fk)))
                print(self.print_prefix + "tol: " + repr(self.func_tol))
            return 1

        # Test the gradient tolerance (the Euclidean norm of gk).
        elif sqrt(dot(gk, gk)) <= self.grad_tol:
            if self.print_flag >= 2:
                print("\n" + self.print_prefix + "Gradient tolerance reached.")
                print(self.print_prefix + "gk+1: " + repr(gk))
                print(self.print_prefix + "||gk+1||: " + repr(sqrt(dot(gk, gk))))
                print(self.print_prefix + "tol: " + repr(self.grad_tol))
            return 1


    def func_test(self, fk_new, fk, gk=None):
        """Default base class function for the function convergence test.

        Test if the minimum function tolerance between fk and fk+1 has been reached.

        @param fk_new:  The function value at iteration k+1.
        @param fk:      The function value at iteration k.
        @param gk:      The gradient vector at iteration k+1 (unused, present for interface compatibility with the other convergence tests).
        @return:        1 if the function tolerance has been reached, None otherwise.
        """

        # Test the function tolerance.
        if abs(fk_new - fk) <= self.func_tol:
            if self.print_flag >= 2:
                print("\n" + self.print_prefix + "Function tolerance reached.")
                print(self.print_prefix + "fk: " + repr(fk))
                print(self.print_prefix + "fk+1: " + repr(fk_new))
                print(self.print_prefix + "|fk+1 - fk|: " + repr(abs(fk_new - fk)))
                print(self.print_prefix + "tol: " + repr(self.func_tol))
            return 1


    def grad_test(self, fk_new, fk, gk):
        """Default base class function for the gradient convergence test.

        Test if the minimum gradient tolerance has been reached.  Minimisation will also terminate if the function value difference between fk and fk+1 is zero.  This modification is essential for the quasi-Newton techniques.

        @param fk_new:  The function value at iteration k+1.
        @param fk:      The function value at iteration k.
        @param gk:      The gradient vector at iteration k+1.
        @return:        1 if the gradient tolerance has been reached or the function difference is zero, None otherwise.
        """

        # Test the gradient tolerance (the Euclidean norm of gk).
        if sqrt(dot(gk, gk)) <= self.grad_tol:
            if self.print_flag >= 2:
                print("\n" + self.print_prefix + "Gradient tolerance reached.")
                print(self.print_prefix + "gk+1: " + repr(gk))
                print(self.print_prefix + "||gk+1||: " + repr(sqrt(dot(gk, gk))))
                print(self.print_prefix + "tol: " + repr(self.grad_tol))
            return 1

        # No change in function value (prevents the minimiser from iterating without moving).
        elif fk_new - fk == 0.0:
            if self.print_flag >= 2:
                print("\n" + self.print_prefix + "Function difference of zero.")
                print(self.print_prefix + "fk_new - fk = 0.0")
            return 1


    def hessian_type_and_mod(self, min_options, default_type='Newton', default_mod='GMW'):
        """Hessian type and modification options.

        Function for sorting out the minimisation options when either the Hessian type or Hessian modification can be selected.  On invalid input self.init_failure is set to 1 and the function returns early.

        @param min_options:   A tuple of up to two strings, in any order: the Hessian type ('BFGS' or 'Newton') and/or the Hessian modification.
        @param default_type:  The Hessian type used when none is supplied.
        @param default_mod:   The Hessian modification intended as the default (currently disabled, see below).
        """

        # Initialise.
        self.hessian_type = None
        self.hessian_mod = None

        # Test if the options are a tuple.
        if not isinstance(min_options, tuple):
            print(self.print_prefix + "The minimisation options " + repr(min_options) + " is not a tuple.")
            self.init_failure = 1
            return

        # Test that no more than 2 options are given.
        if len(min_options) > 2:
            print(self.print_prefix + "A maximum of two minimisation options is allowed (the Hessian type and Hessian modification).")
            self.init_failure = 1
            return

        # Sort out the minimisation options.
        for opt in min_options:
            if self.hessian_type == None and opt != None and (match('[Bb][Ff][Gg][Ss]', opt) or match('[Nn]ewton', opt)):
                self.hessian_type = opt
            elif self.hessian_mod == None and self.valid_hessian_mod(opt):
                self.hessian_mod = opt
            else:
                print(self.print_prefix + "The minimisation option " + repr(opt) + " from " + repr(min_options) + " is neither a valid Hessian type nor modification.")
                self.init_failure = 1
                return

        # Default Hessian type.
        if self.hessian_type == None:
            self.hessian_type = default_type

        # Make sure that no Hessian modification is used with the BFGS matrix.
        if match('[Bb][Ff][Gg][Ss]', self.hessian_type) and self.hessian_mod != None:
            print(self.print_prefix + "When using the BFGS matrix, Hessian modifications should not be used.")
            self.init_failure = 1
            return

        # Default Hessian modification when the Hessian type is Newton.
        # NOTE(review): the default_mod fallback is deliberately disabled here - hessian_mod stays None (unmodified Hessian).
        if match('[Nn]ewton', self.hessian_type) and self.hessian_mod == None:
            self.hessian_mod = None
            #self.hessian_mod = default_mod

        # Print the Hessian type info.
        if self.print_flag:
            if match('[Bb][Ff][Gg][Ss]', self.hessian_type):
                print(self.print_prefix + "Hessian type: BFGS")
            else:
                print(self.print_prefix + "Hessian type: Newton")


    def minimise(self):
        """Main minimisation iterative loop algorithm.

        This algorithm is designed to be compatible with all iterative minimisers.  The outline is:

            - k = 0
            - while 1:
                - New parameter function
                - Convergence tests
                - Update function
                - k = k + 1

        @return:  If self.full_output is set, the tuple (parameter vector, function value, iteration count, function count, gradient count, Hessian count, warning string); otherwise just the parameter vector.
        """

        # Start the iteration counter.
        self.k = 0
        if self.print_flag:
            # k2 is a separate counter used only to throttle printout to every 100th iteration.
            self.k2 = 0
            print("")    # Print a new line.

        # Iterate until the local minima is found.
        while True:
            # Print out.
            if self.print_flag:
                out = 0
                if self.print_flag >= 2:
                    print("\n" + self.print_prefix + "Main iteration k=" + repr(self.k))
                    out = 1
                else:
                    if self.k2 == 100:
                        self.k2 = 0
                    if self.k2 == 0:
                        out = 1
                if out == 1:
                    # NOTE(review): print_iter() is defined in a part of the module not shown in this extract.
                    print_iter(self.k, self.xk, self.fk, print_prefix=self.print_prefix)

            # Get xk+1 (new parameter function).
            try:
                self.new_param_func()
            except LinAlgError:
                # The args of the exception may be (errno, strerror) or just (strerror,), hence the int test.
                message = sys.exc_info()[1]
                if isinstance(message.args[0], int):
                    text = message.args[1]
                else:
                    text = message.args[0]
                self.warning = "LinAlgError: " + text + " (fatal minimisation error)."
                break
            except OverflowError:
                message = sys.exc_info()[1]
                if isinstance(message.args[0], int):
                    text = message.args[1]
                else:
                    text = message.args[0]
                self.warning = "OverflowError: " + text + " (fatal minimisation error)."
                break
            except NameError:
                message = sys.exc_info()[1]
                self.warning = message.args[0] + " (fatal minimisation error)."
                break

            # Test for warnings set by the new parameter function itself.
            if self.warning != None:
                break

            # Maximum number of iteration test.
            if self.k >= self.maxiter - 1:
                self.warning = "Maximum number of iterations reached"
                break

            # Convergence test (self.conv_test is bound by setup_conv_tests()).
            if self.conv_test(self.fk_new, self.fk, self.dfk_new):
                break

            # Infinite function value.
            if self.fk_new == inf:
                self.warning = "Infinite function value encountered, can no longer perform optimisation."
                break

            # Update function.
            try:
                self.update()
            except OverflowError:
                message = sys.exc_info()[1]
                if isinstance(message.args[0], int):
                    text = message.args[1]
                else:
                    text = message.args[0]
                self.warning = "OverflowError: " + text + " (fatal minimisation error)."
                break
            except NameError:
                message = sys.exc_info()[1]
                if isinstance(message.args[0], int):
                    self.warning = message.args[1]
                else:
                    self.warning = message.args[0]
                break

            # Iteration counter update.
            self.k = self.k + 1
            if self.print_flag:
                self.k2 = self.k2 + 1

        # Return the results; fall back to xk/fk when the loop broke before xk_new existed.
        if self.full_output:
            try:
                return self.xk_new, self.fk_new, self.k+1, self.f_count, self.g_count, self.h_count, self.warning
            except AttributeError:
                return self.xk, self.fk, self.k, self.f_count, self.g_count, self.h_count, self.warning
        else:
            try:
                return self.xk_new
            except AttributeError:
                return self.xk


    def setup_conv_tests(self):
        """Default base class for selecting the convergence tests.

        Binds self.conv_test to double_test, func_test, or grad_test depending on which of func_tol and grad_tol are set.  If neither is set, self.init_failure is set to 1.
        """

        if self.func_tol != None and self.grad_tol != None:
            self.conv_test = self.double_test
        elif self.func_tol != None:
            self.conv_test = self.func_test
        elif self.grad_tol != None:
            self.conv_test = self.grad_test
        else:
            print(self.print_prefix + "Convergence tests cannot be setup because both func_tol and grad_tol are set to None.")
            self.init_failure = 1
            return


    def update(self):
        """Default base class update function.

        xk+1 is shifted to xk
        fk+1 is shifted to fk
        """

        # The multiplication by 1.0 copies the array rather than aliasing it.
        self.xk = self.xk_new * 1.0
        self.fk = self.fk_new





# The base class containing the generic line search functions.
##############################################################
class Line_search:
    def __init__(self):
        """Base class containing the generic line search functions."""


    def backline(self):
        """Function for running the backtracking line search."""

        self.alpha, fc = backtrack(self.func, self.args, self.xk, self.fk, self.dfk, self.pk, a_init=self.a0)
        self.f_count += fc


    def line_search_options(self, min_options):
        """Line search options.

        Function for sorting out the minimisation options when the only option can be a line search.
        """

        # Start with no algorithm selected.
        self.line_search_algor = None

        # The options must be packaged as a tuple.
        if not isinstance(min_options, tuple):
            print(self.print_prefix + "The minimisation options " + repr(min_options) + " is not a tuple.")
            self.init_failure = 1
            return

        # A single option at most.
        if len(min_options) > 1:
            print(self.print_prefix + "A maximum of one minimisation options is allowed (the line search algorithm).")
            self.init_failure = 1
            return

        # Accept the option only if it names a known line search algorithm.
        for option in min_options:
            if not self.valid_line_search(option):
                print(self.print_prefix + "The minimisation option " + repr(option) + " from " + repr(min_options) + " is not a valid line search algorithm.")
                self.init_failure = 1
                return
            self.line_search_algor = option

        # Fall back to the backtracking line search.
        if self.line_search_algor == None:
            self.line_search_algor = 'Back'


    def mt(self):
        """Function for running the More and Thuente line search."""

        self.alpha, fc, gc = more_thuente(self.func, self.dfunc, self.args, self.xk, self.fk, self.dfk, self.pk, a_init=self.a0, mu=self.mu, eta=self.eta, print_flag=0)
        self.f_count += fc
        self.g_count += gc


    def no_search(self):
        """Set alpha to alpha0."""

        self.alpha = self.a0


    def nwi(self):
        """Function for running the Nocedal and Wright interpolation based line search."""

        self.alpha, fc = nocedal_wright_interpol(self.func, self.args, self.xk, self.fk, self.dfk, self.pk, a_init=self.a0, mu=self.mu, print_flag=0)
        self.f_count += fc


    def nww(self):
        """Function for running the Nocedal and Wright line search for the Wolfe conditions."""

        self.alpha, fc, gc = nocedal_wright_wolfe(self.func, self.dfunc, self.args, self.xk, self.fk, self.dfk, self.pk, a_init=self.a0, mu=self.mu, eta=self.eta, print_flag=0)
        self.f_count += fc
        self.g_count += gc


    def setup_line_search(self):
        """The line search function.

        Bind self.line_search to the method matching self.line_search_algor, printing which algorithm was chosen when self.print_flag is set.
        """

        if self.line_search_algor == None:
            self.init_failure = 1
            return

        # Dispatch table mapping recognised name patterns to a description and the bound line search method.
        dispatch = [
            (['^[Bb]ack'], "Backtracking line search.", self.backline),
            (['^[Nn]ocedal[ _][Ww]right[ _][Ii]nt', '^[Nn][Ww][Ii]'], "Nocedal and Wright interpolation based line search.", self.nwi),
            (['^[Nn]ocedal[ _][Ww]right[ _][Ww]olfe', '^[Nn][Ww][Ww]'], "Nocedal and Wright line search for the Wolfe conditions.", self.nww),
            (['^[Mm]ore[ _][Tt]huente$', '^[Mm][Tt]'], "More and Thuente line search.", self.mt),
            (['^[Nn]o [Ll]ine [Ss]earch$'], "No line search.", self.no_search),
        ]
        for patterns, description, method in dispatch:
            for pattern in patterns:
                if match(pattern, self.line_search_algor):
                    if self.print_flag:
                        print(self.print_prefix + "Line search: " + description)
                    self.line_search = method
                    return


    def valid_line_search(self, type):
        """Test if the string 'type' is a valid line search algorithm."""

        if type == None:
            return 0

        # The recognised line search algorithm name patterns.
        patterns = ['^[Bb]ack', '^[Nn]ocedal[ _][Ww]right[ _][Ii]nt', '^[Nn][Ww][Ii]', '^[Nn]ocedal[ _][Ww]right[ _][Ww]olfe', '^[Nn][Ww][Ww]', '^[Mm]ore[ _][Tt]huente$', '^[Mm][Tt]', '^[Nn]o [Ll]ine [Ss]earch$']
        for pattern in patterns:
            if match(pattern, type):
                return 1
        return 0





# The base class containing the generic trust-region functions.
###############################################################
class Trust_region:
    def __init__(self):
        """Base class containing the generic trust-region functions."""


    def trust_region_update(self):
        """Select the trust region radius for the next iteration and decide whether to shift xk.

        This is the radius selection algorithm from page 68 of 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.

        The ratio between the actual and the model-predicted reduction::

                  f(xk) - f(xk + pk)
            rho = ------------------,
                    mk(0) - mk(pk)

        drives both the new radius self.delta and the move decision self.shift_flag.
        """

        # The actual reduction in the function value.
        actual = self.fk - self.fk_new

        # The reduction predicted by the quadratic model.
        predicted = - dot(self.dfk, self.pk) - 0.5 * dot(self.pk, dot(self.d2fk, self.pk))

        # The ratio rho, guarding against a zero denominator.
        self.rho = 1e99 if predicted == 0.0 else actual / predicted

        # The Euclidean norm of the step pk.
        self.norm_pk = sqrt(dot(self.pk, self.pk))

        if self.print_flag >= 2:
            print(self.print_prefix + "Actual reduction: " + repr(actual))
            print(self.print_prefix + "Predicted reduction: " + repr(predicted))
            print(self.print_prefix + "rho: " + repr(self.rho))
            print(self.print_prefix + "||pk||: " + repr(self.norm_pk))

        # A small or negative ratio, or a negative prediction, shrinks the trust region.
        if self.rho < 0.25 or predicted < 0.0:
            self.delta = 0.25 * self.delta
            if self.print_flag >= 2:
                print(self.print_prefix + "Shrinking the trust region.")

        # A good ratio with pk at the boundary of the trust region expands it (capped at delta_max).
        elif self.rho > 0.75 and abs(self.norm_pk - self.delta) < 1e-5:
            self.delta = min(2.0*self.delta, self.delta_max)
            if self.print_flag >= 2:
                print(self.print_prefix + "Expanding the trust region.")

        # Otherwise the radius is kept as is.
        else:
            if self.print_flag >= 2:
                print(self.print_prefix + "Trust region is unaltered.")

        if self.print_flag >= 2:
            print(self.print_prefix + "New trust region: " + repr(self.delta))

        # Accept the step only when rho exceeds eta and the model predicted a decrease.
        if self.rho > self.eta and predicted > 0.0:
            self.shift_flag = 1
            if self.print_flag >= 2:
                print(self.print_prefix + "rho > eta, " + repr(self.rho) + " > " + repr(self.eta))
                print(self.print_prefix + "Moving to, self.xk_new: " + repr(self.xk_new))
        else:
            self.shift_flag = 0
            if self.print_flag >= 2:
                print(self.print_prefix + "rho < eta, " + repr(self.rho) + " < " + repr(self.eta))
                print(self.print_prefix + "Not moving, self.xk: " + repr(self.xk))





# The base class containing the generic conjugate gradient functions.
#####################################################################
class Conjugate_gradient:
    def __init__(self):
        """Class containing the non-specific conjugate gradient code."""


    def new_param_func(self):
        """The new parameter function.

        Do a line search then calculate xk+1, fk+1, and gk+1.
        """

        # Determine the step length alpha via the bound line search.
        self.line_search()

        # Step to the new position, then evaluate the function and gradient there.
        self.xk_new = self.xk + self.alpha * self.pk
        self.fk_new = self.func(*(self.xk_new,)+self.args)
        self.f_count += 1
        self.dfk_new = self.dfunc(*(self.xk_new,)+self.args)
        self.g_count += 1

        if self.print_flag >= 2:
            print(self.print_prefix + "New param func:")
            print(self.print_prefix + "\ta: " + repr(self.alpha))
            print(self.print_prefix + "\tpk: " + repr(self.pk))
            print(self.print_prefix + "\txk: " + repr(self.xk))
            print(self.print_prefix + "\txk+1: " + repr(self.xk_new))
            print(self.print_prefix + "\tfk: " + repr(self.fk))
            print(self.print_prefix + "\tfk+1: " + repr(self.fk_new))
            print(self.print_prefix + "\tgk: " + repr(self.dfk))
            print(self.print_prefix + "\tgk+1: " + repr(self.dfk_new))


    def old_cg_conv_test(self):
        """Convergence tests.

        This is old code implementing the conjugate gradient convergence test given on page 124 of 'Numerical Optimization' by Jorge Nocedal and Stephen J. Wright, 1999, 2nd ed.  This function is currently unused.
        """

        # The infinity norm of the gradient vector.
        inf_norm = 0.0
        for element in self.dfk:
            magnitude = abs(element)
            if magnitude > inf_norm:
                inf_norm = magnitude

        # The scaled gradient norm test.
        if inf_norm < self.grad_tol * (1.0 + abs(self.fk)):
            return 1

        # The zero function difference test.
        elif self.fk_new - self.fk == 0.0:
            self.warning = "Function tol of zero reached."
            return 1


    def update(self):
        """Function to update the function value, gradient vector, and Hessian matrix"""

        # Gradient dot product at k+1.
        self.dot_dfk_new = dot(self.dfk_new, self.dfk_new)

        # Calculate beta at k+1.
        beta = self.calc_bk()

        # Restart (reset beta to zero) when successive gradients are far from orthogonal.
        if abs(dot(self.dfk_new, self.dfk)) / self.dot_dfk_new >= 0.1:
            if self.print_flag >= 2:
                print(self.print_prefix + "Restarting.")
            beta = 0

        # The new search direction pk+1.
        self.pk_new = beta * self.pk - self.dfk_new

        if self.print_flag >= 2:
            print(self.print_prefix + "Update func:")
            print(self.print_prefix + "\tpk: " + repr(self.pk))
            print(self.print_prefix + "\tpk+1: " + repr(self.pk_new))

        # Shift k+1 to k (multiplying by 1.0 copies each array rather than aliasing it).
        self.xk = 1.0 * self.xk_new
        self.fk = self.fk_new
        self.dfk = 1.0 * self.dfk_new
        self.pk = 1.0 * self.pk_new
        self.dot_dfk = self.dot_dfk_new




# The base class containing the Hessian modifications.
######################################################
class Hessian_mods:
    def __init__(self):
        """Base class containing the generic Hessian modification functions."""


    def cholesky_mod(self, return_matrix=0):
        """Function for running the Cholesky Hessian modification (Cholesky with added multiple of the identity)."""

        return cholesky_mod(self.dfk, self.d2fk, self.I, self.n, self.print_prefix, self.print_flag, return_matrix)


    def eigenvalue(self, return_matrix=0):
        """Function for running the eigenvalue Hessian modification."""

        return eigenvalue(self.dfk, self.d2fk, self.I, self.print_prefix, self.print_flag, return_matrix)


    def gmw(self, return_matrix=0):
        """Function for running the Gill, Murray, and Wright modified Cholesky algorithm."""

        return gmw(self.dfk, self.d2fk, self.I, self.n, self.mach_acc, self.print_prefix, self.print_flag, return_matrix)


    def gmw_old(self, return_matrix=0):
        """Function for running the old implementation of the Gill, Murray, and Wright modified Cholesky algorithm."""

        return gmw_old(self.dfk, self.d2fk, self.I, self.n, self.mach_acc, self.print_prefix, self.print_flag, return_matrix)


    def se99(self, return_matrix=0):
        """Function for running the revised modified Cholesky factorisation algorithm of Schnabel and Eskow, 1999."""

        return se99(self.dfk, self.d2fk, self.I, self.n, self.tau, self.tau_bar, self.mu, self.print_prefix, self.print_flag, return_matrix)


    def setup_hessian_mod(self):
        """Initialise the Hessian modification functions.

        Binds self.get_pk to the Hessian modification method matching self.hessian_mod, and for the Schnabel and Eskow 1999 algorithm also initialises its tau, tau_bar, and mu parameters.
        """

        # Unmodified Hessian.
        if self.hessian_mod == None or match('^[Nn]o [Hh]essian [Mm]od', self.hessian_mod):
            if self.print_flag:
                print(self.print_prefix + "Hessian modification: Unmodified Hessian.")
            self.get_pk = self.unmodified_hessian

        # Eigenvalue modification.
        elif match('^[Ee]igen', self.hessian_mod):
            if self.print_flag:
                print(self.print_prefix + "Hessian modification: Eigenvalue modification.")
            self.get_pk = self.eigenvalue

        # Cholesky with added multiple of the identity.
        elif match('^[Cc]hol', self.hessian_mod):
            if self.print_flag:
                print(self.print_prefix + "Hessian modification: Cholesky with added multiple of the identity.")
            self.get_pk = self.cholesky_mod

        # The Gill, Murray, and Wright modified Cholesky algorithm.
        elif match('^[Gg][Mm][Ww]$', self.hessian_mod):
            if self.print_flag:
                print(self.print_prefix + "Hessian modification: The Gill, Murray, and Wright modified Cholesky algorithm.")
            self.get_pk = self.gmw

        # The old implementation of the Gill, Murray, and Wright modified Cholesky algorithm.
        elif match('^[Gg][Mm][Ww][ -_]old', self.hessian_mod):
            if self.print_flag:
                print(self.print_prefix + "Hessian modification: The Gill, Murray, and Wright modified Cholesky algorithm.")
            self.get_pk = self.gmw_old

        # The revised modified cholesky factorisation algorithm of Schnabel and Eskow, 99.
        elif match('^[Ss][Ee]99', self.hessian_mod):
            if self.print_flag:
                print(self.print_prefix + "Hessian modification: The Schnabel and Eskow 1999 algorithm.")
            # The SE99 algorithm parameters, derived from the machine accuracy.
            self.tau = self.mach_acc ** (1.0/3.0)
            self.tau_bar = self.mach_acc ** (2.0/3.0)
            self.mu = 0.1
            self.get_pk = self.se99


    def unmodified_hessian(self, return_matrix=0):
        """Calculate the pure Newton direction.

        @param return_matrix:  If true, also return the (unmodified) Hessian matrix.
        @return:               The Newton step -inv(d2fk).dfk, optionally followed by d2fk.
        """

        if return_matrix:
            return -dot(inv(self.d2fk), self.dfk), self.d2fk
        else:
            return -dot(inv(self.d2fk), self.dfk)


    def valid_hessian_mod(self, mod):
        """Test if the string 'mod' is a valid Hessian modification.

        @param mod:  The Hessian modification name, or None for an unmodified Hessian.
        @return:     1 if 'mod' is valid, 0 otherwise.
        """

        # No modification is a valid choice (the unmodified Hessian).
        if mod == None:
            return 1

        # Match against all recognised modification name patterns.
        for pattern in ['^[Ee]igen', '^[Cc]hol', '^[Gg][Mm][Ww]$', '^[Gg][Mm][Ww][ -_]old', '^[Ss][Ee]99', '^[Nn]o [Hh]essian [Mm]od']:
            if match(pattern, mod):
                return 1
        return 0