Package model_selection :: Module cross_validation
[hide private]
[frames] | [no frames]

Source Code for Module model_selection.cross_validation

  1  # A method based on cross-validation model selection. 
  2  # 
  3  # This implements one-item-out cross-validation. 
  4  # 
  5  # The program is divided into the following stages: 
  6  #       Stage 1:  Creation of the files for the model-free calculations for models 1 to 5.  For each model, 
  7  #               a directory for each relaxation data set is created without including the data.  Monte Carlo 
  8  #               simulations are not used on these initial runs, because the errors are not needed (should 
  9  #               speed up analysis considerably). 
 10  #       Stage 2:  Model selection and the creation of the final run.  Monte Carlo simulations are used to 
 11  #               find errors.  This stage has the option of optimizing the diffusion tensor along with the 
 12  #               model-free parameters. 
 13  #       Stage 3:  Extraction of the data. 
 14   
 15  import sys 
 16  from re import match 
 17   
 18  from common_ops import common_operations 
 19   
 20   
class cv(common_operations):
    """Model-free analysis based on cross-validation model selection.

    One-item-out cross-validation: for every model in self.mf.data.runs and
    every relaxation data set i, a model-free fit made without data set i is
    used to back-calculate data set i.  The chi-squared of that
    back-calculation gives the cross-validation criterion, which is averaged
    per residue; the model with the lowest average is selected.
    """

    # Model-free parameter keys passed to calc_relax_data.calc() for each
    # model, in the order they are passed.  Models are recognised by name
    # prefix, tested with re.match in this order.
    _MODEL_PARAMS = (
        ('m1', ('s2',)),
        ('m2', ('s2', 'te')),
        ('m3', ('s2', 'rex')),
        ('m4', ('s2', 'te', 'rex')),
        ('m5', ('s2f', 's2s', 'te')),
    )

    # Rows of the per-data-set parameter table written by print_data():
    # (row label, parameter key, value format, error format).  The error is
    # read from the parameter key with an '_err' suffix.
    _DATA_ROWS = (
        ('S2', 's2', '%9.3f', '%-9.3f'),
        ('S2f', 's2f', '%9.3f', '%-9.3f'),
        ('S2s', 's2s', '%9.3f', '%-9.3f'),
        ('te', 'te', '%9.2f', '%-9.2f'),
        ('Rex', 'rex', '%9.3f', '%-9.3f'),
    )

    def __init__(self, mf):
        """Model-free analysis based on cross-validation model selection methods.

        mf: the main program object; all data, user parameters, logging and
            helper objects used by this class are reached through it.
        """
        self.mf = mf

    def _data_set_label(self, i):
        """Return '<frequency label>_<data type>' for relaxation data set i."""
        data = self.mf.data
        return data.frq_label[data.remap_table[i]][1] + "_" + data.data_types[i]

    def _cv_model_name(self, model, i):
        """Return the run name of `model` fitted with data set i left out.

        The name is both the sub-directory under `model` and the key into
        self.mf.data.data.  Every method builds it here so the keys written
        by extract_mf_data() always match the keys read afterwards.
        """
        return model + "-" + self._data_set_label(i)

    def _model_param_names(self, model):
        """Return the parameter keys passed to calc_relax_data.calc() for `model`.

        Raises ValueError when `model` matches none of m1-m5, rather than
        continuing with undefined or stale back-calculated data.
        """
        for prefix, names in self._MODEL_PARAMS:
            if match(prefix, model):
                return names
        raise ValueError("Unknown model-free model: " + repr(model))

    def extract_mf_data(self):
        """Extract the model-free results.

        For every model and every left-out data set, read
        '<model>/<cv model>/mfout' and store the result of
        self.mf.star.extract() in self.mf.data.data[<cv model>].
        """
        num_res = len(self.mf.data.relax_data[0])
        for model in self.mf.data.runs:
            print("Extracting model-free data of model " + model)
            for i in range(self.mf.data.num_ri):
                cv_model = self._cv_model_name(model, i)
                cv_dir = model + "/" + cv_model
                print("\t" + cv_dir + "/mfout.")
                mfout = self.mf.file_ops.read_file(cv_dir + '/mfout')
                try:
                    mfout_lines = mfout.readlines()
                finally:
                    # Close the file even if reading it fails.
                    mfout.close()
                self.mf.data.data[cv_model] = self.mf.star.extract(
                    mfout_lines, num_res,
                    self.mf.usr_param.chi2_lim, self.mf.usr_param.ftest_lim,
                    ftest='n', sims='n')

    def fill_results(self, data, model='0'):
        """Initialise the next row of the results data structure.

        data:  mapping that supplies the residue's 'res_num'.
        model: the selected model number as a string; model_selection()
               passes '0' when even the best criterion is infinite.

        Returns a dict with 'res_num' and 'model' set and every parameter,
        error and 'chi2' entry set to ''.
        """
        results = {'res_num': data['res_num'], 'model': model}
        for key in ('s2', 's2_err', 's2f', 's2f_err', 's2s', 's2s_err',
                    'te', 'te_err', 'rex', 'rex_err', 'chi2'):
            results[key] = ''
        return results

    def model_selection(self):
        """Model selection.

        For each residue, compute every model's cross-validation criterion
        averaged over the left-out data sets, store it in
        self.mf.data.cv.cv_crit[res][model], and store the selected model's
        row in self.mf.data.results[res].
        """
        data = self.mf.data.data
        self.mf.data.calc_frq()
        self.mf.data.calc_constants()
        tm = float(self.mf.usr_param.tm['val']) * 1e-9

        if self.mf.debug:
            self.mf.log.write("\n\n<<< " + self.mf.usr_param.method + " model selection >>>\n\n")

        for res in range(len(self.mf.data.relax_data[0])):
            sys.stdout.write("%9s" % "Residue: ")
            sys.stdout.write("%-9s" % (self.mf.data.relax_data[0][res][1] + " " + self.mf.data.relax_data[0][res][0]))
            self.mf.data.cv.cv_crit.append({})
            self.mf.data.results.append({})
            crit = self.mf.data.cv.cv_crit[res]

            if self.mf.debug:
                self.mf.log.write('%-22s\n' % ("Checking res " + data[self._cv_model_name('m1', 0)][res]['res_num']))

            for model in self.mf.data.runs:
                if self.mf.debug:
                    self.mf.log.write(model + "\n")
                param_names = self._model_param_names(model)
                sum_cv_crit = 0

                for i in range(self.mf.data.num_ri):
                    # Parameters fitted without data set i, used to predict it.
                    fit = data[self._cv_model_name(model, i)][res]
                    real = [float(self.mf.data.relax_data[i][res][2])]
                    err = [float(self.mf.data.relax_data[i][res][3])]
                    types = [[self.mf.data.data_types[i], float(self.mf.data.frq[self.mf.data.remap_table[i]])]]

                    back_calc = self.mf.calc_relax_data.calc(tm, model, types, [fit[name] for name in param_names])
                    chi2 = self.mf.calc_chi2.relax_data(real, err, back_calc)
                    # NOTE(review): presumably 1.0 is the number of left-out
                    # data points in the 2*n normalisation — confirm.
                    cv_crit = chi2 / (2.0 * 1.0)
                    sum_cv_crit = sum_cv_crit + cv_crit

                    if self.mf.debug:
                        self.mf.log.write("%7s%-10.4f%2s" % (" Chi2: ", chi2, " |"))
                        self.mf.log.write("%10s%-14.4f%2s\n\n" % (" CV crit: ", cv_crit, " |"))

                # NOTE(review): the sum runs over num_ri data sets but is
                # divided by len(relax_data); presumably equal — confirm.
                crit[model] = sum_cv_crit / float(len(self.mf.data.relax_data))

                if self.mf.debug:
                    self.mf.log.write("%13s%-10.4f\n\n" % ("Ave CV crit: ", crit[model]))

            # Select the model with the lowest criterion; the strict '<'
            # means a tie keeps the model already chosen.
            best = 'm1'
            for model in self.mf.data.runs:
                if crit[model] < crit[best]:
                    best = model
            # If even the best criterion is infinite, record model '0'.
            if crit[best] == float('inf'):
                selected = '0'
            else:
                selected = best[1]
            self.mf.data.results[res] = self.fill_results(data[self._cv_model_name(best, 0)][res], model=selected)

            if self.mf.debug:
                for model in self.mf.data.runs:
                    self.mf.log.write(self.mf.usr_param.method + " (" + model + "): " + repr(crit[model]) + "\n")
                self.mf.log.write("The selected model is: " + best + "\n\n")

            sys.stdout.write("%10s\n" % ("Model " + self.mf.data.results[res]['model']))

    def print_data(self):
        """Print all the data into the 'data_all' file.

        Also writes one line per residue into the 'crit' file: residue
        number, selected model, and the criterion of every model.  Both
        files are closed even if writing fails.
        """
        data_file = open('data_all', 'w')
        try:
            crit_file = open('crit', 'w')
            try:
                self._write_data_files(data_file, crit_file)
            finally:
                crit_file.close()
        finally:
            data_file.close()

    def _write_data_files(self, data_file, crit_file):
        """Write the 'data_all' and 'crit' contents for every residue."""
        # NOTE(review): the indentation of the original listing was lost; the
        # CV row and 'crit' line are reconstructed as once-per-residue (after
        # the data-set loop) and the final newline as once per file — confirm.
        runs = self.mf.data.runs
        sys.stdout.write("[")
        for res in range(len(self.mf.data.results)):
            sys.stdout.write("-")
            result = self.mf.data.results[res]
            data_file.write("\n\n<<< Residue " + result['res_num'])
            data_file.write(", Model " + result['model'] + " >>>\n")
            data_file.write('%-20s' % '')
            for num in range(1, 6):
                data_file.write('%-19s' % ('Model %d' % num))

            crit_file.write('%-6s' % result['res_num'])
            crit_file.write('%-6s' % result['model'])

            for i in range(self.mf.data.num_ri):
                data_file.write("\n-" + self._data_set_label(i))
                for label, key, val_fmt, err_fmt in self._DATA_ROWS:
                    data_file.write('\n%-20s' % label)
                    for model in runs:
                        fit = self.mf.data.data[self._cv_model_name(model, i)][res]
                        data_file.write(val_fmt % fit[key])
                        data_file.write('%1s' % '±')
                        data_file.write(err_fmt % fit[key + '_err'])

            # Cross validation criteria.
            data_file.write('\n%-20s' % 'CV')
            for model in runs:
                data_file.write('%-19.3f' % self.mf.data.cv.cv_crit[res][model])
                crit_file.write('%-25s' % repr(self.mf.data.cv.cv_crit[res][model]))
            crit_file.write('\n')

        data_file.write('\n')
        sys.stdout.write("]\n")

    def print_results(self):
        """Print the results into the results file.

        Writes a 'ResNo Model' header and then one line per residue; the
        file is closed even if writing fails.
        """
        results_file = open('results', 'w')
        try:
            results_file.write('%-6s%-6s\n' % ('ResNo', 'Model'))
            sys.stdout.write("[")
            for result in self.mf.data.results:
                sys.stdout.write("-")
                results_file.write('%-6s' % result['res_num'])
                results_file.write('%-6s\n' % result['model'])
            sys.stdout.write("]\n")
        finally:
            results_file.close()