1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 """Module for selecting the best model."""
24
25
26 import sys
27
28
29 from lib.errors import RelaxError, RelaxPipeError
30 from lib.io import write_data
31 from lib.model_selection import aic, aicc, bic
32 from pipe_control import interatomic, mol_res_spin
33 import pipe_control.pipes
34 from pipe_control.pipes import has_pipe, pipe_names, switch
35 from specific_analyses.api import return_api
36
37
def select(method=None, modsel_pipe=None, bundle=None, pipes=None):
    """Model selection function.

    For every model yielded by the specific analysis model loop, the selection criterion is
    calculated for each candidate data pipe. The data of the pipe with the lowest criterion
    is then duplicated into the new model selection data pipe.

    @keyword method:        The model selection method.  This can currently be one of:
                                - 'AIC', Akaike's Information Criteria.
                                - 'AICc', Small sample size corrected AIC.
                                - 'BIC', Bayesian or Schwarz Information Criteria.
                                - 'CV', Single-item-out cross-validation.
                            None of the other model selection techniques are currently supported,
                            and 'CV' is currently rejected with a RelaxError.
    @type method:           str
    @keyword modsel_pipe:   The name of the new data pipe to be created by copying of the selected data pipe.
    @type modsel_pipe:      str
    @keyword bundle:        The optional data pipe bundle to associate the newly created pipe with.
    @type bundle:           str or None
    @keyword pipes:         A list of the data pipes to use in the model selection.  If None, all
                            data pipes are used.  A list of lists of pipe names selects the
                            cross-validation setup.
    @type pipes:            list of str or list of list of str or None
    @raises RelaxPipeError: If the modsel_pipe data pipe already exists.
    @raises RelaxError:     If the method is not supported, no data pipes are available, the
                            cross-validation pipe groups are empty or inconsistent, or pipe groups
                            are given for a non cross-validation method.
    """

    # The new model selection pipe must not already exist.
    if has_pipe(modsel_pipe):
        raise RelaxPipeError(modsel_pipe)

    # Default to all data pipe names from the relax data store.
    if pipes is None:
        pipes = pipe_names()

    # Select the model selection technique.
    if method == 'AIC':
        print("AIC model selection.")
        formula = aic
    elif method == 'AICc':
        print("AICc model selection.")
        formula = aicc
    elif method == 'BIC':
        print("BIC model selection.")
        formula = bic
    elif method == 'CV':
        # Announced but not yet supported, so the method == 'CV' criterion branch below is
        # currently unreachable.
        print("CV model selection.")
        raise RelaxError("The model selection technique " + repr(method) + " is not currently supported.")
    else:
        raise RelaxError("The model selection technique " + repr(method) + " is not currently supported.")

    # No pipes.
    if len(pipes) == 0:
        raise RelaxError("No data pipes are available for use in model selection.")

    # Groups of pipes (a list of lists) are only meaningful for cross-validation.
    cv_setup = isinstance(pipes[0], list)
    if cv_setup:
        # Each group must hold at least one pipe (an empty group would also give a
        # division by zero in the cross-validation criterion).
        for group in pipes:
            if len(group) == 0:
                raise RelaxError("No pipes are available for use in model selection in the array " + repr(group) + ".")

        # Pipe groups cannot be handled by the per-pipe criterion (the group would be used
        # as a dictionary key below), so fail early with a clear message.
        if method != 'CV':
            raise RelaxError("Groups of data pipes are only supported for cross-validation model selection.")

        all_pipes = [pipe for group in pipes for pipe in group]
        first_pipe = pipes[0][0]
    else:
        all_pipes = list(pipes)
        first_pipe = pipes[0]

    # Specific analysis functions, keyed by data pipe name.
    model_loops = {}
    model_type = {}
    duplicate_data = {}
    model_statistics = {}
    skip_function = {}
    modsel_pipe_exists = False

    # Store the functions from the specific analysis API of each individual pipe.
    for pipe in all_pipes:
        api = return_api(pipe_name=pipe)
        model_loops[pipe] = api.model_loop
        model_type[pipe] = api.model_type
        duplicate_data[pipe] = api.duplicate_data
        model_statistics[pipe] = api.model_statistics
        skip_function[pipe] = api.skip_function

    # For cross-validation, the model loop of the first pipe drives all pipes, so every
    # pipe must share it.
    if cv_setup:
        reference_loop = model_loops[first_pipe]
        for pipe in all_pipes:
            if model_loops[pipe] != reference_loop:
                raise RelaxError("The models for each data pipes should be the same.")

    # Alias the model loop and description functions from the API of the first data pipe.
    api = return_api(pipe_name=first_pipe)
    model_loop = api.model_loop
    model_desc = api.model_desc

    # Global statistics are used if any of the pipes holds a global model.
    global_flag = any(model_type[pipe]() == 'global' for pipe in all_pipes)

    # Loop over the models.
    for model_info in model_loop():
        # Print out the model description.
        print("\n")
        desc = model_desc(model_info)
        if desc:
            print(desc)

        # Initial model.
        best_model = None
        best_crit = float('inf')
        data = []

        # Loop over the candidate pipes (or pipe groups for cross-validation).
        for candidate in pipes:
            if method == 'CV':
                # Cross-validation criterion: the chi2 summed over the group, divided by the
                # group size.
                sum_crit = 0.0
                for pipe in candidate:
                    switch(pipe)

                    # Skip the model if requested by the specific analysis.
                    if skip_function[pipe](model_info):
                        continue

                    k, n, chi2 = model_statistics[pipe](model_info)

                    # Missing statistics.
                    if k is None or n is None or chi2 is None:
                        continue

                    sum_crit = sum_crit + chi2

                crit = sum_crit / float(len(candidate))

            else:
                pipe = candidate
                switch(pipe)

                # Skip the model if requested by the specific analysis.
                if skip_function[pipe](model_info):
                    continue

                k, n, chi2 = model_statistics[pipe](model_info, global_stats=global_flag)

                # Missing statistics.
                if k is None or n is None or chi2 is None:
                    continue

                crit = formula(chi2, float(k), float(n))

            # Store the values for the printout table.
            data.append([pipe, repr(k), repr(n), "%.5f" % chi2, "%.5f" % crit])

            # Keep the pipe with the lowest criterion.
            if crit < best_crit:
                best_model = pipe
                best_crit = crit

        # Print out the table of statistics for this model.
        write_data(out=sys.stdout, headings=["Data pipe", "Num_params_(k)", "Num_data_sets_(n)", "Chi2", "Criterion"], data=data)

        # Duplicate the data of the selected pipe into the model selection pipe.
        if best_model is not None:
            print("The model from the data pipe " + repr(best_model) + " has been selected.")
            switch(best_model)
            duplicate_data[best_model](best_model, modsel_pipe, model_info, global_stats=global_flag, verbose=False)
            modsel_pipe_exists = True
        else:
            print("No model has been selected.")

    # Switch to the model selection pipe. Only a pipe that was actually created can be bundled.
    if modsel_pipe_exists:
        switch(modsel_pipe)
        if bundle:
            pipe_control.pipes.bundle(bundle=bundle, pipe=modsel_pipe)

    # Update all of the required metadata structures.
    mol_res_spin.metadata_update()
    interatomic.metadata_update()
265