Skip to content

Commit

Permalink
Merge pull request #8 from pinellolab/fix_op
Browse files Browse the repository at this point in the history
Add FixOperator for fake node training
  • Loading branch information
huidongchen authored Nov 11, 2022
2 parents e0d3bbf + e15675a commit 4cd0f08
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions torchbiggraph/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
def get_operator_params_for_reg(self) -> Optional[FloatTensorType]:
return None

@OPERATORS.register_as("fix")
class FixOperator(AbstractOperator):
# Detach node tensor that the loss isn't propagated to the node embedding.
def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
match_shape(embeddings, ..., self.dim)
return embeddings.clone().detach()

def get_operator_params_for_reg(self) -> Optional[FloatTensorType]:
return None

@OPERATORS.register_as("diagonal")
class DiagonalOperator(AbstractOperator):
Expand Down

0 comments on commit 4cd0f08

Please sign in to comment.