from __future__ import annotations

import functools
import weakref
from types import BuiltinFunctionType
from types import FunctionType as FuncType
from types import MethodDescriptorType
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

import torch
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    PrivateAttr,
    Strict,
    ValidationError,
    field_validator,
    model_serializer,
)
from pydantic.functional_validators import AfterValidator, BeforeValidator
from typing_extensions import Annotated, Self

from ...intervention.graph import InterventionGraph, InterventionNode
from ...tracing.graph import Graph, Node, SubGraph
from . import FUNCTIONS_WHITELIST, get_function_name

if TYPE_CHECKING:
    from ... import NNsight
# Kinds of callables that FunctionModel serializes by name: builtin functions,
# plain Python functions, method descriptors, and classes (``type``).
FUNCTION = Union[BuiltinFunctionType, FuncType, MethodDescriptorType, type]
# Scalar values that ValueTypes accepts as-is, without wrapping them in a model.
PRIMITIVE = Union[int, float, str, bool, None]
class DeserializeHandler:
    """State shared by every ``deserialize`` call during one deserialization pass.

    Args:
        memo: Mapping indexed by ``MemoReferenceModel.id``. Its entries start
            out as serialized models; ``MemoReferenceModel.deserialize``
            resolves an entry and then overwrites it with the resolved value.
        model: The ``NNsight`` model, handed to ``InterventionGraph`` when an
            intervention graph is rebuilt.

    Attributes:
        graph: A fresh ``Graph`` (with ``InterventionNode`` nodes) into which
            ``NodeModel.deserialize`` creates the rebuilt nodes.
    """

    def __init__(self, memo, model: "NNsight") -> None:
        self.memo = memo
        self.model = model
        # Every deserialized node is created inside this one graph.
        self.graph = Graph(node_class=InterventionNode)
# Module-level memo written by ``memoized`` converters: maps ``id(obj)`` to the
# model built for ``obj``. ``check_memo`` consults it to turn repeat occurrences
# of an object into MemoReferenceModel references.
# NOTE(review): nothing here clears it, and ``id()`` values can be reused once
# an object is freed. Presumably the caller resets it per serialization; confirm.
MEMO: Dict[int, Any] = {}
class BaseNNsightModel(BaseModel):
    """Common base class of every serialized nnsight value.

    Subclasses override ``type_name`` with a unique literal tag (the
    discriminator of the ``TOTYPES`` union). They typically implement
    ``to_model`` to build the model from a live object, and ``deserialize`` to
    rebuild the live object from the model.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Placeholder tag; each concrete subclass narrows it to its own literal.
    type_name: Literal["TYPE_NAME"]

    @classmethod
    def to_model(cls, value: Any) -> Self:
        """Convert *value* into this model type. Must be overridden."""
        raise NotImplementedError

    def deserialize(self, handler: DeserializeHandler):
        """Rebuild the live object this model describes. Must be overridden."""
        raise NotImplementedError
def try_deserialize(value: Union[BaseNNsightModel, Any], handler: DeserializeHandler):
    """Deserialize *value* when it is a serialized model; otherwise return it unchanged.

    Args:
        value: Either a ``BaseNNsightModel`` or an already-plain value (e.g. a primitive).
        handler: Shared deserialization state passed to ``deserialize``.
    """
    if not isinstance(value, BaseNNsightModel):
        return value
    return value.deserialize(handler)
def memoized(fn):
    """Decorate a ``to_model`` converter so its results are recorded in ``MEMO``.

    The wrapper calls *fn* on the value and stores the resulting model in the
    module-level ``MEMO`` under ``id(value)``. It then returns a
    ``MemoReferenceModel`` pointing at that entry instead of the model itself.
    Later occurrences of the same object are turned into references by
    ``check_memo``, so the object is serialized only once.

    ``functools.wraps`` keeps the converter's name, docstring and
    ``__wrapped__`` visible on the wrapper (useful for debugging and
    introspection).
    """

    @functools.wraps(fn)
    def inner(value):
        model = fn(value)
        # Keyed by identity, not equality: distinct but equal objects are
        # serialized separately.
        obj_id = id(value)
        MEMO[obj_id] = model
        return MemoReferenceModel(id=obj_id)

    return inner
### Custom Pydantic models, one per supported value type


class NodeModel(BaseNNsightModel):
    """Serialized graph ``Node``: its target plus positional and keyword arguments."""

    type_name: Literal["NODE"] = "NODE"
    target: ValueTypes
    args: List[ValueTypes] = []
    kwargs: Dict[str, ValueTypes] = {}

    @staticmethod
    @memoized
    def to_model(value: Node) -> Self:
        # ``memoized`` records this model in MEMO and hands back a MemoReferenceModel.
        return NodeModel(target=value.target, args=value.args, kwargs=value.kwargs)

    @model_serializer(mode="wrap")
    def serialize_model(self, handler):
        """Dump normally, then omit ``args`` / ``kwargs`` when they are empty."""
        dump = handler(self)
        for field_name in ("kwargs", "args"):
            if not getattr(self, field_name):
                dump.pop(field_name)
        return dump

    def deserialize(self, handler: DeserializeHandler) -> Node:
        """Recreate the node inside ``handler.graph`` and return it."""
        # Order matters: target first, then positional args left to right, then
        # kwargs. Resolving a reference can itself create nodes in handler.graph.
        target = self.target.deserialize(handler)
        positional = [try_deserialize(arg, handler) for arg in self.args]
        keyword = {
            name: try_deserialize(arg, handler) for name, arg in self.kwargs.items()
        }
        return handler.graph.create(target, *positional, **keyword).node
class TensorModel(BaseNNsightModel):
    """Serialized ``torch.Tensor``: its values as Python data plus its dtype name.

    The shape is implied by the nesting of ``values``. The device is not
    recorded; ``deserialize`` builds the tensor with ``torch.tensor``'s default
    device.
    """

    type_name: Literal["TENSOR"] = "TENSOR"
    # ``Tensor.tolist()`` yields nested lists for tensors with >= 1 dimension,
    # but a bare Python number for 0-dim tensors, so scalars are accepted too.
    values: Union[List, int, float, bool]
    # Name of the dtype attribute on the ``torch`` module, e.g. "float32".
    dtype: str

    @staticmethod
    @memoized
    def to_model(value: torch.Tensor) -> Self:
        return TensorModel(values=value.tolist(), dtype=str(value.dtype).split(".")[-1])

    def deserialize(self, handler: DeserializeHandler) -> torch.Tensor:
        """Rebuild the tensor.

        Raises:
            AttributeError: if ``torch`` has no attribute named ``self.dtype``.
            TypeError: if that attribute exists but is not a ``torch.dtype``.
        """
        dtype = getattr(torch, self.dtype)
        # ``self.dtype`` may come from an external payload: only real dtypes are
        # accepted, not arbitrary attributes of the ``torch`` module.
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"`{self.dtype}` is not a valid torch dtype.")
        return torch.tensor(self.values, dtype=dtype)
class SliceModel(BaseNNsightModel):
    """Serialized ``slice``; each bound may itself be any supported value."""

    type_name: Literal["SLICE"] = "SLICE"
    start: ValueTypes
    stop: ValueTypes
    step: ValueTypes

    @staticmethod
    @memoized
    def to_model(value: slice) -> Self:
        return SliceModel(start=value.start, stop=value.stop, step=value.step)

    def deserialize(self, handler: DeserializeHandler) -> slice:
        # Bounds are resolved in start, stop, step order.
        bounds = [
            try_deserialize(bound, handler)
            for bound in (self.start, self.stop, self.step)
        ]
        return slice(*bounds)
class EllipsisModel(BaseNNsightModel):
    """Serialized ``...`` (Ellipsis); it carries no data beyond its type tag."""

    type_name: Literal["ELLIPSIS"] = "ELLIPSIS"

    # ``types.EllipsisType`` would be the natural return annotation, but it
    # requires Python >= 3.10.
    def deserialize(self, handler: DeserializeHandler) -> type(...):
        return Ellipsis
class ListModel(BaseNNsightModel):
    """Serialized ``list``; every element is validated as a ``ValueTypes``."""

    type_name: Literal["LIST"] = "LIST"
    values: List[ValueTypes]

    @staticmethod
    def to_model(value: List) -> Self:
        return ListModel(values=value)

    def deserialize(self, handler: DeserializeHandler) -> list:
        restored = []
        for item in self.values:
            restored.append(try_deserialize(item, handler))
        return restored
class TupleModel(BaseNNsightModel):
    """Serialized ``tuple``; stored as a list of values and rebuilt as a tuple."""

    type_name: Literal["TUPLE"] = "TUPLE"
    values: List[ValueTypes]

    @staticmethod
    def to_model(value: Tuple) -> Self:
        return TupleModel(values=value)

    def deserialize(self, handler: DeserializeHandler) -> tuple:
        return tuple(try_deserialize(item, handler) for item in self.values)
class DictModel(BaseNNsightModel):
    """Serialized ``dict`` with string keys; each value is a ``ValueTypes``."""

    type_name: Literal["DICT"] = "DICT"
    values: Dict[str, ValueTypes]

    @staticmethod
    def to_model(value: Dict) -> Self:
        return DictModel(values=value)

    def deserialize(self, handler: DeserializeHandler) -> dict:
        restored = {}
        for key, item in self.values.items():
            restored[key] = try_deserialize(item, handler)
        return restored
class FunctionWhitelistError(Exception):
    """Raised when a function name is not present in ``FUNCTIONS_WHITELIST``."""
class FunctionModel(BaseNNsightModel):
    """Serialized callable, stored only by its whitelisted qualified name."""

    type_name: Literal["FUNCTION"] = "FUNCTION"
    function_name: str

    @staticmethod
    def to_model(value: FUNCTION):
        serialized = FunctionModel(function_name=get_function_name(value))
        # Refuse to emit a name that ``deserialize`` would refuse to resolve.
        FunctionModel.check_function_whitelist(serialized.function_name)
        return serialized

    @classmethod
    def check_function_whitelist(cls, qualname: str) -> str:
        """Return *qualname* unchanged if whitelisted.

        Raises:
            FunctionWhitelistError: if *qualname* is not in ``FUNCTIONS_WHITELIST``.
        """
        if qualname in FUNCTIONS_WHITELIST:
            return qualname
        raise FunctionWhitelistError(
            f"Function with name `{qualname}` not in function whitelist."
        )

    def deserialize(self, handler: DeserializeHandler) -> FUNCTION:
        # Re-check on the receiving side: the payload is not trusted.
        FunctionModel.check_function_whitelist(self.function_name)
        return FUNCTIONS_WHITELIST[self.function_name]
class GraphModel(BaseNNsightModel):
    """Serialized ``Graph``: the list of its nodes.

    Nodes pass through ``NodeType``, whose memoized converter turns each one
    into a ``MemoReferenceModel``.
    """

    type_name: Literal["GRAPH"] = "GRAPH"
    # The live Graph is kept on the model for optimization logic. It is excluded
    # from dumps, so it is None when the model is rebuilt from a payload; hence
    # Optional rather than Graph.
    graph: Optional[Graph] = Field(exclude=True, default=None, validate_default=False)
    nodes: List[Union[MemoReferenceModel, NodeType]]

    @staticmethod
    def to_model(value: Graph) -> Self:
        return GraphModel(graph=value, nodes=value.nodes)

    def deserialize(self, handler: DeserializeHandler) -> Graph:
        # Deserializing a node creates it inside handler.graph, so the populated
        # shared graph is what gets returned.
        for node in self.nodes:
            node.deserialize(handler)
        return handler.graph
class SubGraphModel(BaseNNsightModel):
    """Serialized ``SubGraph``: only the integer ``subset`` it was built with.

    The subset is presumably a list of node indices into the parent graph;
    confirm against ``SubGraph``.
    """

    type_name: Literal["SUBGRAPH"] = "SUBGRAPH"
    subset: List[int]

    @staticmethod
    def to_model(value: SubGraph) -> Self:
        return SubGraphModel(subset=value.subset)

    def deserialize(self, handler: DeserializeHandler) -> Graph:
        subgraph = SubGraph(handler.graph, subset=self.subset)
        # Re-point every member node at the rebuilt subgraph.
        for member in subgraph:
            member.graph = subgraph
        return subgraph
class InterventionGraphModel(SubGraphModel):
    """Serialized ``InterventionGraph``: a subset rebuilt against ``handler.model``."""

    type_name: Literal["INTERVENTIONGRAPH"] = "INTERVENTIONGRAPH"

    @staticmethod
    def to_model(value: InterventionGraph) -> Self:
        return InterventionGraphModel(subset=value.subset)

    def deserialize(self, handler: DeserializeHandler) -> Graph:
        rebuilt = InterventionGraph(
            handler.graph, model=handler.model, subset=self.subset
        )
        # Re-point every member node at the rebuilt intervention graph.
        for member in rebuilt:
            member.graph = rebuilt
        return rebuilt
class MemoReferenceModel(BaseNNsightModel):
    """Reference to a memo entry, keyed by the original object's ``id()``."""

    type_name: Literal["REFERENCE"] = "REFERENCE"
    id: int

    def deserialize(self, handler: DeserializeHandler):
        """Resolve the reference and cache the result back into ``handler.memo``.

        The first resolution deserializes the stored model and overwrites the
        memo entry with the result. Later references to the same id get that
        same object back, because ``try_deserialize`` returns non-models
        unchanged.
        """
        resolved = try_deserialize(handler.memo[self.id], handler)
        handler.memo[self.id] = resolved
        return resolved
### Define Annotated types to convert objects to their custom Pydantic counterpart
# Each alias validates an instance of the plain Python / torch / nnsight type.
# Its AfterValidator then replaces the instance with the matching model (or,
# for memoized converters, with a MemoReferenceModel).


def _to_ellipsis_model(_: Any) -> EllipsisModel:
    """Map every ``...`` to the same data-less EllipsisModel."""
    return EllipsisModel()


GraphType = Annotated[Graph, AfterValidator(GraphModel.to_model)]
SubGraphType = Annotated[SubGraph, AfterValidator(SubGraphModel.to_model)]
InterventionGraphType = Annotated[
    InterventionGraph, AfterValidator(InterventionGraphModel.to_model)
]
TensorType = Annotated[torch.Tensor, AfterValidator(TensorModel.to_model)]
SliceType = Annotated[slice, AfterValidator(SliceModel.to_model)]
# ``type(...)`` because ``types.EllipsisType`` requires Python >= 3.10.
EllipsisType = Annotated[type(...), AfterValidator(_to_ellipsis_model)]
ListType = Annotated[list, AfterValidator(ListModel.to_model)]
# Strict: only real tuples match. Lists, which lax tuple validation would
# accept, fall through to ListType.
TupleType = Annotated[tuple, Strict(), AfterValidator(TupleModel.to_model)]
DictType = Annotated[dict, AfterValidator(DictModel.to_model)]
FunctionType = Annotated[FUNCTION, AfterValidator(FunctionModel.to_model)]
NodeType = Annotated[Node, AfterValidator(NodeModel.to_model)]
def check_memo(object: Any):
    """Return a MemoReferenceModel if *object* was already memoized, else raise.

    Used as the BeforeValidator of ``MemoType``, the first member of the
    left-to-right ``FROMTYPES`` union. Raising ValueError here makes pydantic
    fall through to the next union member. The message keeps pydantic's
    aggregated union error report readable; previously this entry was blank.

    Raises:
        ValueError: if ``id(object)`` is not a key of ``MEMO``.
    """
    # The parameter name shadows the builtin ``object``; it is kept unchanged
    # for compatibility with any keyword callers.
    obj_id = id(object)
    if obj_id not in MEMO:
        raise ValueError(f"Object with id {obj_id} has not been memoized.")
    return MemoReferenceModel(id=obj_id)


MemoType = Annotated[object, BeforeValidator(check_memo)]
### Register all custom Pydantic models that objects are converted *to*.
# Discriminated on ``type_name``: when a dumped payload is validated back in,
# each value is routed directly to the model class whose tag it carries.
TOTYPES = Annotated[
    Union[
        MemoReferenceModel,
        NodeModel,
        SliceModel,
        TensorModel,
        TupleModel,
        ListModel,
        DictModel,
        FunctionModel,
        EllipsisModel,
        InterventionGraphModel,
        SubGraphModel,
        GraphModel,
    ],
    Field(discriminator="type_name"),
]
### Register all Annotated types that objects are converted *from*.
# ``left_to_right``: members are tried in the listed order and the first one
# that validates wins, so the order below is significant:
#   - MemoType first, so already-memoized objects become references;
#   - InterventionGraphType, SubGraphType, GraphType go from most to least
#     specific (NOTE(review): assumes InterventionGraph subclasses SubGraph,
#     which subclasses Graph; confirm);
#   - TupleType (strict) before ListType.
FROMTYPES = Annotated[
    Union[
        MemoType,
        NodeType,
        InterventionGraphType,
        SubGraphType,
        GraphType,
        FunctionType,
        SliceType,
        TensorType,
        TupleType,
        ListType,
        DictType,
        EllipsisType,
    ],
    Field(union_mode="left_to_right"),
]
### Final registration
# The type of every model field that may hold an arbitrary value (node target,
# args and kwargs, slice bounds, list/tuple/dict contents). Primitives pass
# through unchanged, already-built models are accepted as-is, and supported
# live objects are converted via FROMTYPES.
ValueTypes = Union[
    PRIMITIVE,
    TOTYPES,
    FROMTYPES,
]