-
-
Notifications
You must be signed in to change notification settings - Fork 132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
change the pdb_paths working style and support for loading both local… #214
base: master
Are you sure you want to change the base?
Changes from 2 commits
524beb2
f5b017c
b24bdeb
ce6d36b
f5da2c4
9a67631
07cd92a
a3dfff8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,12 +41,12 @@ def __init__( | |
self, | ||
root: str, | ||
name: str, | ||
pdb_paths: Optional[List[str]] = None, | ||
pdb_codes: Optional[List[str]] = None, | ||
uniprot_ids: Optional[List[str]] = None, | ||
graph_label_map: Optional[Dict[str, torch.Tensor]] = None, | ||
node_label_map: Optional[Dict[str, torch.Tensor]] = None, | ||
chain_selection_map: Optional[Dict[str, List[str]]] = None, | ||
pdb_paths: Optional[List[str]] = [], | ||
pdb_codes: Optional[List[str]] = [], | ||
uniprot_ids: Optional[List[str]] = [], | ||
graph_labels: Optional[List[torch.Tensor]] = None, | ||
node_labels: Optional[List[torch.Tensor]] = None, | ||
chain_selections: Optional[List[str]] = None, | ||
graphein_config: ProteinGraphConfig = ProteinGraphConfig(), | ||
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( | ||
src_format="nx", dst_format="pyg" | ||
|
@@ -73,13 +73,13 @@ def __init__( | |
:type root: str | ||
:param name: Name of the dataset. Will be saved to ``data_$name.pt``. | ||
:type name: str | ||
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``. | ||
:param pdb_paths: List of full path of pdb files to load. Defaults to ``List``. | ||
:type pdb_paths: Optional[List[str]], optional | ||
:param pdb_codes: List of PDB codes to download and parse from the PDB. | ||
Defaults to None. | ||
Defaults to List. | ||
:type pdb_codes: Optional[List[str]], optional | ||
:param uniprot_ids: List of Uniprot IDs to download and parse from | ||
Alphafold Database. Defaults to ``None``. | ||
Alphafold Database. Defaults to ``List``. | ||
:type uniprot_ids: Optional[List[str]], optional | ||
:param graph_label_map: Dictionary mapping PDB/Uniprot IDs to | ||
graph-level labels. Defaults to ``None``. | ||
|
@@ -130,54 +130,56 @@ def __init__( | |
self.pdb_codes = ( | ||
[pdb.lower() for pdb in pdb_codes] | ||
if pdb_codes is not None | ||
else None | ||
else [] | ||
) | ||
self.uniprot_ids = ( | ||
[up.upper() for up in uniprot_ids] | ||
if uniprot_ids is not None | ||
else None | ||
else [] | ||
) | ||
|
||
self.pdb_paths = pdb_paths | ||
if self.pdb_paths is None: | ||
if self.pdb_codes and self.uniprot_ids: | ||
self.structures = self.pdb_codes + self.uniprot_ids | ||
elif self.pdb_codes: | ||
self.structures = pdb_codes | ||
elif self.uniprot_ids: | ||
self.structures = uniprot_ids | ||
# Use local saved pdb_files instead of download or move them to self.root/raw dir | ||
else: | ||
if isinstance(self.pdb_paths, list): | ||
self.structures = [ | ||
# make sure root path is unique | ||
if self.pdb_paths: | ||
# add pdb_paths' name into self.structure | ||
self.pdb_paths_name = [ | ||
os.path.splitext(os.path.split(pdb_path)[-1])[0] | ||
for pdb_path in self.pdb_paths | ||
] | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
|
||
if self.pdb_codes and self.uniprot_ids: | ||
self.structures = self.pdb_codes + self.uniprot_ids | ||
elif self.pdb_codes: | ||
self.structures = pdb_codes | ||
elif self.uniprot_ids: | ||
self.structures = uniprot_ids | ||
self.af_version = af_version | ||
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir | ||
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1: | ||
raise ValueError("pdb_paths should have only one root path not so much!") | ||
else: | ||
self.pdb_paths_name = [] | ||
|
||
self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this should be a set operation. With chain selections you may want to have e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, i guess it would'n make a difference at chain selection, this set operation is to drop duplicate in the result list of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It becomes a problem here though (L283), no? def process(self):
"""Process structures into PyG format and save to disk."""
# Read data into huge `Data` list.
structure_files = [
f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures
] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, i guess not. from graphein.ml.datasets import InMemoryProteinGraphDataset
local_dir = "../protein/test_data"
pdb_paths = [osp.join(local_dir, pdb_file) for pdb_file in os.listdir(local_dir) if pdb_file.endswith(".pdb")]
ds = InMemoryProteinGraphDataset(root = "../protein/test_data/InMemoryProteinGraphDataset",
name = "InMemoryProteinGraphDataset_test",
pdb_paths=pdb_paths,
pdb_codes=["10gs"],
uniprot_ids=["A0A6J1BG53", "A0A6P5Z5F7"],
af_version=3) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you see what happens with: from graphein.ml.dataset import InMemoryProteinGraphDataset
ds = InMemoryProteinGraphDataset(root = ""../protein.test_data/InMemoryProteinGraphDataset", pdb_paths=pdb_paths, pdb_codes = ["4hhb", "4hhb"], chain_selection=["A","B"]) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well, i'll try later There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
and why i I guess this may need lots of change~ |
||
|
||
# Labels & Chains | ||
if graph_labels is not None: | ||
self.graph_label_map = dict(enumerate(graph_labels)) | ||
else: | ||
self.graph_label_map = None | ||
|
||
if node_labels is not None: | ||
self.node_label_map = dict(enumerate(node_labels)) | ||
else: | ||
self.node_label_map = None | ||
if chain_selections is not None: | ||
self.chain_selection_map = dict(enumerate(chain_selections)) | ||
else: | ||
self.chain_selection_map = None | ||
self.validate_input() | ||
self.bad_pdbs: List[ | ||
str | ||
] = [] # list of pdb codes that failed to download | ||
|
||
# Labels & Chains | ||
self.graph_label_map = graph_label_map | ||
self.node_label_map = node_label_map | ||
self.chain_selection_map = chain_selection_map | ||
|
||
# Configs | ||
self.config = graphein_config | ||
self.graph_format_convertor = graph_format_convertor | ||
self.graph_transformation_funcs = graph_transformation_funcs | ||
self.pdb_transform = pdb_transform | ||
self.num_cores = num_cores | ||
self.af_version = af_version | ||
|
||
super().__init__( | ||
root, | ||
transform=transform, | ||
|
@@ -200,10 +202,34 @@ def processed_file_names(self) -> List[str]: | |
@property | ||
def raw_dir(self) -> str: | ||
if self.pdb_paths is not None: | ||
return self.pdb_path # replace raw dir with user local pdb_path | ||
# replace raw dir with user local pdb_path; so pdb_paths should be located in the same place | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
return self.pdb_path | ||
else: | ||
return os.path.join(self.root, "raw") | ||
|
||
def validate_input(self): | ||
if self.graph_label_map is not None: | ||
assert len(self.structures) == len( | ||
self.graph_label_map | ||
), "Number of proteins and graph labels must match" | ||
if self.node_label_map is not None: | ||
assert len(self.structures) == len( | ||
self.node_label_map | ||
), "Number of proteins and node labels must match" | ||
if self.chain_selection_map is not None: | ||
assert len(self.structures) == len( | ||
self.chain_selection_map | ||
), "Number of proteins and chain selections must match" | ||
assert len( | ||
{ | ||
f"{pdb}_{chain}" | ||
for pdb, chain in zip( | ||
self.structures, self.chain_selection_map | ||
) | ||
} | ||
) == len(self.structures), "Duplicate protein/chain combinations" | ||
|
||
def download(self): | ||
"""Download the PDB files from RCSB or Alphafold.""" | ||
self.config.pdb_dir = Path(self.raw_dir) | ||
|
@@ -225,6 +251,7 @@ def download(self): | |
for pdb in set(self.pdb_codes) | ||
if not os.path.exists(Path(self.raw_dir) / f"{pdb}.pdb") | ||
] | ||
print("downloading uniprotids") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ohhhhhhhhhhh, too sry for these I'll remove them today |
||
if self.uniprot_ids: | ||
[ | ||
download_alphafold_structure( | ||
|
@@ -237,6 +264,7 @@ def download(self): | |
] | ||
|
||
def __len__(self) -> int: | ||
"""Returns length of data set (number of structures).""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should return the number of examples (not just the number of structures for the multiple chain reason I mentioned previously) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok |
||
return len(self.structures) | ||
|
||
def transform_pdbs(self): | ||
|
@@ -327,15 +355,12 @@ class ProteinGraphDataset(Dataset): | |
def __init__( | ||
self, | ||
root: str, | ||
pdb_paths: Optional[List[str]] = None, | ||
pdb_codes: Optional[List[str]] = None, | ||
uniprot_ids: Optional[List[str]] = None, | ||
# graph_label_map: Optional[Dict[str, int]] = None, | ||
pdb_paths: Optional[List[str]] = [], | ||
pdb_codes: Optional[List[str]] = [], | ||
uniprot_ids: Optional[List[str]] = [], | ||
graph_labels: Optional[List[torch.Tensor]] = None, | ||
node_labels: Optional[List[torch.Tensor]] = None, | ||
chain_selections: Optional[List[str]] = None, | ||
# node_label_map: Optional[Dict[str, int]] = None, | ||
# chain_selection_map: Optional[Dict[str, List[str]]] = None, | ||
graphein_config: ProteinGraphConfig = ProteinGraphConfig(), | ||
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor( | ||
src_format="nx", dst_format="pyg" | ||
|
@@ -356,22 +381,20 @@ def __init__( | |
|
||
:param root: Root directory where the dataset should be saved. | ||
:type root: str | ||
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``. | ||
:param pdb_paths: List of full path of pdb files to load. Defaults to ``List``. | ||
:type pdb_paths: Optional[List[str]], optional | ||
:param pdb_codes: List of PDB codes to download and parse from the PDB. | ||
Defaults to ``None``. | ||
Defaults to ``List``. | ||
:type pdb_codes: Optional[List[str]], optional | ||
:param uniprot_ids: List of Uniprot IDs to download and parse from | ||
Alphafold Database. Defaults to ``None``. | ||
Alphafold Database. Defaults to ``List``. | ||
:type uniprot_ids: Optional[List[str]], optional | ||
:param graph_label_map: Dictionary mapping PDB/Uniprot IDs to | ||
graph-level labels. Defaults to ``None``. | ||
:type graph_label_map: Optional[Dict[str, Tensor]], optional | ||
:param node_label_map: Dictionary mapping PDB/Uniprot IDs to node-level | ||
labels. Defaults to ``None``. | ||
:type node_label_map: Optional[Dict[str, torch.Tensor]], optional | ||
:param chain_selection_map: Dictionary mapping, defaults to ``None``. | ||
:type chain_selection_map: Optional[Dict[str, List[str]]], optional | ||
:param graph_labels: List mapping to self.structures by index to graph-level labels. Defaults to ``None``. | ||
:type graph_labels: Optional[List[torch.Tensor]], optional | ||
:param node_labels: List mapping to self.structures by index to node-level labels. Defaults to ``None``. | ||
:type node_labels: Optional[List[torch.Tensor]], optional | ||
:param chain_selections: List mapping to self.structures by index to chain selection, defaults to ``None``. | ||
:type chain_selections: Optional[List[str]], optional | ||
:param graphein_config: Protein graph construction config, defaults to | ||
``ProteinGraphConfig()``. | ||
:type graphein_config: ProteinGraphConfig, optional | ||
|
@@ -412,34 +435,32 @@ def __init__( | |
self.pdb_codes = ( | ||
[pdb.lower() for pdb in pdb_codes] | ||
if pdb_codes is not None | ||
else None | ||
else [] | ||
) | ||
self.uniprot_ids = ( | ||
[up.upper() for up in uniprot_ids] | ||
if uniprot_ids is not None | ||
else None | ||
else [] | ||
) | ||
self.pdb_paths = pdb_paths | ||
if self.pdb_paths is None: | ||
if self.pdb_codes and self.uniprot_ids: | ||
self.structures = self.pdb_codes + self.uniprot_ids | ||
elif self.pdb_codes: | ||
self.structures = pdb_codes | ||
elif self.uniprot_ids: | ||
self.structures = uniprot_ids | ||
# Use local saved pdb_files instead of download or move them to self.root/raw dir | ||
else: | ||
if isinstance(self.pdb_paths, list): | ||
self.structures = [ | ||
# make sure root path is unique | ||
if self.pdb_paths: | ||
# add pdb_paths' name into self.structure | ||
self.pdb_paths_name = [ | ||
os.path.splitext(os.path.split(pdb_path)[-1])[0] | ||
for pdb_path in self.pdb_paths | ||
] | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
|
||
# Labels & Chains | ||
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir | ||
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1: | ||
raise ValueError("pdb_paths should have only one root path not so much!") | ||
else: | ||
self.pdb_paths_name = [] | ||
|
||
self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately | ||
|
||
self.examples: Dict[int, str] = dict(enumerate(self.structures)) | ||
|
||
# Labels & Chains | ||
if graph_labels is not None: | ||
self.graph_label_map = dict(enumerate(graph_labels)) | ||
else: | ||
|
@@ -460,9 +481,9 @@ def __init__( | |
# Configs | ||
self.config = graphein_config | ||
self.graph_format_convertor = graph_format_convertor | ||
self.num_cores = num_cores | ||
self.pdb_transform = pdb_transform | ||
self.graph_transformation_funcs = graph_transformation_funcs | ||
self.pdb_transform = pdb_transform | ||
self.num_cores = num_cores | ||
self.af_version = af_version | ||
super().__init__( | ||
root, | ||
|
@@ -492,8 +513,10 @@ def processed_file_names(self) -> List[str]: | |
|
||
@property | ||
def raw_dir(self) -> str: | ||
if self.pdb_paths is not None: | ||
return self.pdb_path # replace raw dir with user local pdb_path | ||
if self.pdb_paths: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually think it would be useful to allow users to choose a path for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I agree. I'm not sure about this, i prefer to dict, which key is the names like
If something wrong in my understanding, please tell me 😄 , i'm still reading and learning your code lol. It's really a pythonic code, i learnt a lot 👍 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't think this is the best idea. I think being explicit about the paths users want to use is best. For instance, people may want to use only a subset of their dataset (rather than everything in the directory - e.g. imagine where you want to keep all your pdb files together but train/test on different subsets). It also has the potential problem with hidden files like
This was my initial implementation. However, this ran into the problem where you may have different examples in your dataset drawn from different chains of the same PDB. E.g. imagine you have
Thanks!! Me too! |
||
# replace raw dir with user local pdb_path; so pdb_paths should be located in the same place | ||
self.pdb_path, _ = os.path.split(self.pdb_paths[0]) | ||
return self.pdb_path | ||
else: | ||
return os.path.join(self.root, "raw") | ||
|
||
|
@@ -610,7 +633,6 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator: | |
) | ||
if self.graph_transformation_funcs is not None: | ||
graphs = [self.transform_graphein_graphs(g) for g in graphs] | ||
|
||
# Convert to PyTorch Geometric Data | ||
graphs = [self.graph_format_convertor(g) for g in graphs] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reasoning for using empty lists as the default arg?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
empty lists
can add together even if they are empty, whileNone
can't. So we can skip someif
for different statements of the user passpdb_paths
orpdb_codes
oruniprot_ids
, and just merge them intoself.structures,
which is used atprocess
func and it works likeos.listdir(self.raw_dir)
.As for some potential bugs, i'm really not sure would this will cause some bugs as i use
empty list
instead ofNone
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be
None
https://stackoverflow.com/questions/366422/what-is-the-pythonic-way-to-avoid-default-parameters-that-are-empty-lists
If you want to retain the behaviour inside the object, you could do:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, and actually i have done that in the latest commit