Skip to content

Commit

Permalink
Merge pull request #236 from VukManojlovic/CTX-5524
Browse files Browse the repository at this point in the history
CTX-5524: Added tag management and inference.py hotfix
  • Loading branch information
dule1322 authored Jul 30, 2024
2 parents 6430f94 + d1f418d commit 542f343
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 6 deletions.
7 changes: 6 additions & 1 deletion coretex/entities/dataset/network_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from .dataset import Dataset
from .state import DatasetState
from ..tag import EntityTagType, Taggable
from ..sample import NetworkSample
from ..utils import isEntityNameValid
from ... import folder_manager
Expand Down Expand Up @@ -91,7 +92,7 @@ def _encryptedSampleImport(sampleType: Type[SampleType], sampleName: str, sample
raise RuntimeError("Unreachable statement was reached.")


class NetworkDataset(Generic[SampleType], Dataset[SampleType], NetworkObject, ABC):
class NetworkDataset(Generic[SampleType], Dataset[SampleType], NetworkObject, Taggable, ABC):

"""
Represents the base class for all Dataset classes which are
Expand Down Expand Up @@ -129,6 +130,10 @@ def path(self) -> Path:

return folder_manager.datasetsFolder / str(self.id)

@property
def entityTagType(self) -> EntityTagType:
return EntityTagType.dataset

# Codable overrides

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion coretex/entities/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
import json
import logging

from ..tag import Taggable, EntityTagType
from ..utils import isEntityNameValid
from ... import folder_manager
from ...networking import networkManager, NetworkObject, ChunkUploadSession, MAX_CHUNK_SIZE, NetworkRequestError
from ...codable import KeyDescriptor


class Model(NetworkObject):
class Model(NetworkObject, Taggable):

"""
Represents a machine learning model object on Coretex.ai
Expand Down Expand Up @@ -82,6 +83,10 @@ def path(self) -> Path:
def zipPath(self) -> Path:
return self.path.with_suffix(".zip")

@property
def entityTagType(self) -> EntityTagType:
return EntityTagType.model

@classmethod
def modelDescriptorFileName(cls) -> str:
"""
Expand Down
152 changes: 152 additions & 0 deletions coretex/entities/tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (C) 2023 Coretex LLC

# This file is part of Coretex.ai

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from enum import IntEnum
from typing import Optional, Dict, Any
from abc import abstractmethod, ABC

import re
import random
import logging

from ..networking import networkManager, NetworkRequestError


class EntityTagType(IntEnum):

model = 1
dataset = 2


class Taggable(ABC):

id: int
projectId: int

@property
@abstractmethod
def entityTagType(self) -> EntityTagType:
pass

def _getTagId(self, tagName: str) -> Optional[int]:
parameters = {
"name": tagName,
"type": self.entityTagType.value,
"project_id": self.projectId
}

response = networkManager.get("tag", parameters)
if response.hasFailed():
raise NetworkRequestError(response, "Failed to check existing tags")

tags = response.getJson(dict).get("data")
if not isinstance(tags, list):
raise NetworkRequestError(response, f"Field \"data\" from tag response must be dict, but got {type(tags)} instead")

if len(tags) == 0:
return None

if not isinstance(tags[0], dict):
raise NetworkRequestError(response, f"Tag object from response must be dict, but got {type(tags[0])} instead")

tagId = tags[0].get("id")
if not isinstance(tagId, int):
raise NetworkRequestError(response, f"Tag object from response must have field id of type int, but got {type(tagId)} instead")

return tagId


def addTag(self, tag: str, color: Optional[str] = None) -> None:
"""
Add a tag to this entity
Parameters
----------
tag : str
name of the tag
color : Optional[str]
a hexadecimal color code for the new tag\n
if tag already exists in project, this will be ignored\n
if left empty and tag does not already exist, a random color will be picked
Raises
------
ValueError
if tag name or color are invalid
NetworkRequestError
if request to add tag failed
"""

if re.match(r"^[a-z0-9-]{1,30}$", tag) is None:
raise ValueError(">> [Coretex] Tag has to be alphanumeric")

if color is None:
color = f"#{random.randint(0, 0xFFFFFF):06x}"
else:
if re.match(r"^#([A-Fa-f0-9]{3}|[A-Fa-f0-9]{6})$", color) is None:
raise ValueError(">> [Coretex] Tag color has to follow hexadecimal color code")

tags: Dict[str, Any] = {}

tagId = self._getTagId(tag)
if tagId is not None:
tags["existing"] = [tagId]
else:
tags["new"] = [{
"name": tag,
"color": color
}]

parameters = {
"entity_id": self.id,
"type": self.entityTagType.value,
"tags": tags
}

response = networkManager.post("tag/entity", parameters)
if response.hasFailed():
raise NetworkRequestError(response, "Failed to create tag")

def removeTag(self, tag: str) -> None:
"""
Remove tag with provided name from the entity
Parameters
----------
tag : str
name of the tag
Raises
------
NetworkRequestError
if tag removal request failed
"""

tagId = self._getTagId(tag)
if tagId is None:
logging.error(f">> [Coretex] Tag \"{tag}\" not found on entity id {self.id}")
return

parameters = {
"entity_id": self.id,
"tag_id": tagId,
"type": self.entityTagType.value
}

response = networkManager.post("tag/remove", parameters)
if response.hasFailed():
raise NetworkRequestError(response, "Failed to remove tag")
15 changes: 11 additions & 4 deletions coretex/utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
async def genWitness(inputPath: Path, circuit: Path, witnessPath: Path) -> None:
await ezkl.gen_witness(inputPath, circuit, witnessPath)

async def getSrs(settings: Path) -> None:
await ezkl.get_srs(settings)


def runOnnxInference(
data: np.ndarray,
onnxPath: Path,
compiledModelPath: Optional[Path] = None,
proveKey: Optional[Path] = None
proveKey: Optional[Path] = None,
settingsPath: Optional[Path] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, Path]]:

"""
Expand All @@ -36,6 +40,8 @@ def runOnnxInference(
data which will be directly fed to the model
onnxPath : Path
path to the onnx model
settingsPath : Path
path to the settigs.json file
compiledModelPath : Optional[Path]
path to the compiled model
proveKey : Optional[Path]
Expand All @@ -53,11 +59,11 @@ def runOnnxInference(
inputName = session.get_inputs()[0].name
result = np.array(session.run(None, {inputName: data}))

if compiledModelPath is None and proveKey is None:
if compiledModelPath is None and proveKey is None and settingsPath is None:
return result

if compiledModelPath is None or proveKey is None:
raise ValueError(f">> [Coretex] Parameters compiledModelPath and proveKey have to either both be passed or None")
if compiledModelPath is None or proveKey is None or settingsPath is None:
raise ValueError(f">> [Coretex] Parameters compiledModelPath, proveKey and settingsPath have to either all be passed (for verified inference) or none of them (for regula inference)")

inferenceDir = folder_manager.createTempFolder(inferenceId)
witnessPath = inferenceDir / "witness.json"
Expand All @@ -70,6 +76,7 @@ def runOnnxInference(
json.dump(inputData, file)

asyncio.run(genWitness(inputPath, compiledModelPath, witnessPath))
asyncio.run(getSrs(settingsPath))
ezkl.prove(
witnessPath,
compiledModelPath,
Expand Down

0 comments on commit 542f343

Please sign in to comment.