00001 import os
00002
00003 import ete2
00004 import TreeLayouts
00005
00006
00007
00008
# Re-execute the TreeLayouts module every time this file is imported
# (presumably so that interactive edits to the layouts are picked up).
reload(TreeLayouts)


def mean(seq):
    """Return the arithmetic mean of seq as a float.

    seq must support both sum() and len(); an empty seq raises
    ZeroDivisionError.
    """
    return float(sum(seq)) / len(seq)


# Node names treated as "no real name"; Tree.IsInternal() and
# Tree.IsModule() test membership in this tuple.
DefaultNames = ("-", "NoName")
00014
00015
00016
00017
00018
class Tree(ete2.Tree):
    """ete2.Tree extended with module queries, pruning and rendering helpers.

    Node categories used throughout the class:
      * root     -- is_root() is true
      * leaf     -- is_leaf() is true
      * internal -- neither root nor leaf, name is one of DefaultNames
      * module   -- neither root nor leaf, name is not in DefaultNames
    """

    def __init__(self, newick=None, format=1):
        """Build a tree from newick text, a newick file path, or nothing.

        newick -- newick text, a path to a newick file, or None.
        format -- newick format code, forwarded to ete2.Tree.__init__.

        When newick is a file path the root node is named after the file:
        the last path component, truncated at its first ".".
        """
        ete2.Tree.__init__(self, newick=newick, format=format)

        # Newick text always terminates with ";", so anything else is taken
        # to be a file path.  (Testing for a trailing ");" instead would
        # misclassify text that carries a root label, a root distance or
        # trailing whitespace, and clobber the root name with the whole
        # newick string.)
        if newick is not None and not newick.rstrip().endswith(";"):
            self.name = newick.split(os.sep)[-1].split(".")[0]

        self.LayoutData = None

    # CamelCase aliases for the ete2 API.  They are real methods rather than
    # bound methods stored on each instance, so every node has them no
    # matter how it was constructed.

    def IsRoot(self):
        """Alias for is_root()."""
        return self.is_root()

    def IsLeaf(self):
        """Alias for is_leaf()."""
        return self.is_leaf()

    def GetRoot(self):
        """Alias for get_tree_root()."""
        return self.get_tree_root()

    def Copy(self):
        """Return a new Tree parsed from this node's format-1 newick text.

        Only what survives the newick round trip is carried over.
        """
        return Tree(self.NWStr())

    def IsInternal(self):
        """True for a non-root, non-leaf node whose name is a default name."""
        if self.is_leaf() or self.is_root():
            return False
        return self.name in DefaultNames

    def HasChild(self, node):
        """True if node is found while traversing this subtree.

        The traversal starts at self, so HasChild(self) is True.
        """
        return any(candidate == node for candidate in self.traverse())

    def IsModule(self):
        """True for a non-root, non-leaf node with a non-default name."""
        if self.is_leaf() or self.is_root():
            return False
        return self.name not in DefaultNames

    def HasModule(self):
        """True if some node of this subtree other than self is a module."""
        return any(n.IsModule() and n != self for n in self.traverse())

    def IsTerminalModule(self):
        """True if self is a module with no module among its descendants."""
        return self.IsModule() and not self.HasModule()

    def GetModules(self):
        """Return {module name: standalone copy of that module's subtree}.

        Each copy gets dist set to the module's MeanDist(), the module's
        name, and has "NoName" nodes renamed to "-".  Modules sharing a
        name overwrite one another; the last one traversed wins.
        """
        modules = {}
        for node in self.traverse():
            if not node.IsModule():
                continue
            copy = node.Copy()
            copy.dist = node.MeanDist()
            copy.name = node.name
            copy.Rename()
            modules[node.name] = copy
        return modules

    def GetNodesByName(self, Names):
        """Return one node per name in Names that occurs in this subtree.

        pre:  Names is a list of strings
        post: result holds, in the order of Names, the first node that
              search_nodes(name=...) reports for each name; names with no
              match are skipped.
        """
        found = []
        for name in Names:
            matches = self.search_nodes(name=name)
            if matches:
                found.append(matches[0])
        return found

    def GetLeafNamesInOrder(self):
        """Return the leaf names below self, following the children order.

        Returns [] when self has no children.
        """
        names = []
        for child in self.children:
            if child.IsLeaf():
                names.append(child.name)
            else:
                names.extend(child.GetLeafNamesInOrder())
        return names

    def GetCommonAncestor(self, Names):
        """Return the lowest node whose subtree holds every matched name.

        Names that do not occur in this subtree are ignored.  The search
        starts at the parent of the first matched node, so a single matched
        name yields that node's parent.  If the first match is self, self
        is returned: every other match lies inside self's subtree already.

        Raises IndexError when none of the names occur in this subtree.
        """
        targets = self.GetNodesByName(Names)
        if not targets:
            raise IndexError("none of the requested names occur in this tree")

        first = targets[0]
        if first is self:
            # Climbing to first.up would leave this subtree, or hit None
            # when self is the root of the whole tree.
            return self

        current = first.up
        while len(current.GetNodesByName(Names)) < len(targets):
            current = current.up
        return current

    def PruneNodesMissing(self, Names):
        """Drop, recursively, every child subtree holding none of Names."""
        # Iterate over a copy: the children list is edited inside the loop.
        for child in self.children[:]:
            if child.GetNodesByName(Names):
                child.PruneNodesMissing(Names)
            else:
                # NOTE(review): the child is only taken out of the list; its
                # "up" attribute still points at self afterwards.
                self.children.remove(child)

    def GetMinTreeHolding(self, Names):
        """Return a pruned copy of the smallest subtree holding Names.

        The copy of the common ancestor is stripped of branches without any
        of Names, has its single-child nodes merged, and is named "-".
        """
        minimal = self.GetCommonAncestor(Names).Copy()
        minimal.PruneNodesMissing(Names)
        minimal.MergeSingles()
        minimal.SetName("-")
        return minimal

    def HasOnlyLeaves(self):
        """True if every direct child is a leaf (also True with no children)."""
        return all(child.IsLeaf() for child in self.children)

    def DistToFurthest(self):
        """Return the distance component of get_farthest_leaf()."""
        return self.get_farthest_leaf()[1]

    def DistToFurthestChild(self):
        """Return the largest dist among the direct children.

        Raises ValueError when self has no children.
        """
        return max([child.dist for child in self.get_children()])

    def MeanDist(self):
        """Return the mean dist over every node of this subtree, self included."""
        return mean([node.dist for node in self.traverse()])

    def Consolidate(self, thresh=0.0):
        """Delete chains of short-branched ancestors above each leaf.

        Starting at each leaf's parent, walk upwards for as long as the
        node is not a root and its dist is <= thresh, then delete the
        nodes walked over.
        """
        def collapse_upwards(node):
            if node.dist <= thresh and not node.is_root():
                collapse_upwards(node.up)
                node.delete()

        # get_leaves() gives a list, so the deletions below cannot disturb
        # the iteration the way a lazy traversal could.
        for leaf in self.get_leaves():
            if leaf.up is not None:  # a single-node tree has no parent
                collapse_upwards(leaf.up)

    def MergeSingles(self):
        """Recursively remove nodes that have exactly one child."""
        for child in self.children[:]:
            if len(child.children) == 1:
                child.MergeSingles()
                # NOTE(review): the removed node's branch length is added
                # to its parent (self), not to its surviving child --
                # confirm this is intended.
                self.dist += child.dist
                child.delete()

        # The deletions above may have attached new children; visit the
        # updated list.
        for child in self.children[:]:
            child.MergeSingles()

    def Save(self, features=None, outfile=None, format=1):
        """Write the tree as newick text via ete2.Tree.write.

        outfile defaults to "<self.name>.nw".  Returns None.
        """
        if outfile is None:
            outfile = self.name + ".nw"

        ete2.Tree.write(self, features=features, outfile=outfile, format=format)

    def NWStr(self, format=1):
        """Return the newick text of this subtree in the given format."""
        return self.write(format=format)

    def Rename(self):
        """Rename every node called "NoName" to "-"."""
        for node in self.traverse():
            if node.name == "NoName":
                node.name = "-"

    def SetName(self, name):
        """Set this node's name."""
        self.name = name

    def nLeafStr(self):
        """Return the nLeaf attribute as a string in parentheses.

        nLeaf is assigned by _SetnLeaf(); it must have been set beforehand.
        """
        return "(" + str(self.nLeaf) + ")"

    def GetNodeType(self):
        """Return the TreeLayouts node-type object matching this node.

        When several categories apply, the precedence is Leaf, Module,
        Internal, Root, so a single-node tree is reported as a Leaf.
        Raises KeyError if no category applies.
        """
        if self.IsLeaf():
            return TreeLayouts.Leaf
        if self.IsModule():
            return TreeLayouts.Module
        if self.IsInternal():
            return TreeLayouts.Internal
        if self.IsRoot():
            return TreeLayouts.Root
        raise KeyError(True)

    def _ExpandAll(self):
        """Set collapsed = False on every node of this subtree."""
        for node in self.traverse():
            node.collapsed = False

    def _SetnLeaf(self):
        """Expand all nodes, then store each node's leaf count in node.nLeaf."""
        self._ExpandAll()

        # Children are visited before their parent, so each count is the
        # sum of counts already stored -- one pass over the tree instead of
        # a get_leaves() call per node.
        for node in self.traverse(strategy="postorder"):
            if node.children:
                node.nLeaf = sum([child.nLeaf for child in node.children])
            else:
                node.nLeaf = 1

    def _CopyDists(self):
        """Save each node's dist in node._Dist so that it can be restored
        after layout functions have modified it."""
        for node in self.traverse():
            node._Dist = node.dist

    def _RestoreDists(self):
        """Restore each node's dist from the copy made by _CopyDists()."""
        for node in self.traverse():
            node.dist = node._Dist

    def SetZeroLeafMean(self):
        """Give every leaf whose dist is 0 the tree's mean dist."""
        mean_dist = self.MeanDist()
        for leaf in self.get_leaves():
            if leaf.dist == 0:
                leaf.dist = mean_dist

    def SetAllLeafMean(self):
        """Give every leaf the tree's mean dist."""
        mean_dist = self.MeanDist()
        for leaf in self.get_leaves():
            leaf.dist = mean_dist

    def SetAllDists(self, dist):
        """Set dist on every node of this subtree."""
        for node in self.traverse():
            node.dist = dist

    def SetAllDistsMean(self):
        """Set every node's dist to the tree's current mean dist."""
        self.SetAllDists(self.MeanDist())

    def SetImgProps(self, Props=None):
        """Placeholder: currently does nothing with Props."""
        pass

    def _Render(self, Layout=None, KeepCollapsed=False, ImgProps=None,
                FileName=None, Width=None, Height=None):
        """Display the tree, or render it to a file, with a named layout.

        Layout        -- key into TreeLayouts.Layouts.
        KeepCollapsed -- when false, all nodes are expanded again afterwards.
        ImgProps      -- handed to SetImgProps(), which is currently a no-op;
                         file rendering always uses TreeLayouts.ImgProps().
        FileName      -- None opens the interactive viewer via show();
                         otherwise the tree is rendered to that file.
        Width, Height -- forwarded to render() as w and h.

        Node distances are saved first and restored afterwards, even when
        showing or rendering fails.
        """
        self._CopyDists()
        self._SetnLeaf()

        self.SetImgProps(ImgProps)
        layout = TreeLayouts.Layouts[Layout]

        try:
            if FileName is None:
                self.show(layout=layout)
            else:
                self.render(file_name=FileName,
                            layout=layout,
                            header=None,
                            w=Width,
                            h=Height,
                            img_properties=TreeLayouts.ImgProps())
        finally:
            # Undo whatever the layout functions changed on the nodes.
            if not KeepCollapsed:
                self._ExpandAll()
            self._RestoreDists()

    def Show(self, Layout="Explore", KeepCollapsed=False):
        """Open the interactive viewer using the named layout."""
        self._Render(Layout=Layout, KeepCollapsed=KeepCollapsed)

    def SavePDF(self, fname=None, Layout="Explore", KeepCollapsed=False):
        """Render the tree to "<fname>_<Layout>.pdf".

        fname defaults to self.name; spaces in the final file name are
        replaced by underscores.
        """
        if fname is None:
            fname = self.name

        fname += "_" + Layout + ".pdf"
        fname = fname.replace(" ", "_")
        self._Render(Layout=Layout, FileName=fname, KeepCollapsed=KeepCollapsed)
00336
00337
00338
00339