nnsight.contexts#
The contexts module contains logic for managing the tracing and running of models with nnsight.tracing and nnsight.envoy. The two primary classes involved here are Tracer and Invoker.
The Tracer class creates a Graph around the underlying model of an NNsight. The graph tracks and manages the operations performed on the inputs and outputs of said model.
Each module’s Envoy in the model exposes its .output and .input attributes which, when accessed, add to the computation graph of the tracer. To do this, they need to know about the current Tracer object, so each module’s Envoy’s .tracer attribute is set to the current Tracer.
The Tracer object also keeps track of the batch size of the most recent input, the generation index for multi-iteration generations, and all of the inputs made during its context in the .batched_input attribute. Inputs added to this attribute should be in a format where each index is a batch index, which allows the model to batch all of the inputs together.
This keeps things consistent: if two inputs arrive in two different valid formats, both are converted to the same format and are easy to batch. In the case of LanguageModels, whether the inputs are string prompts, pre-processed dictionaries, or input ids, the batched input is always input ids. On exiting the Tracer context, the Tracer object uses the information and inputs provided to it to carry out the execution of the model.
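This normalization step can be illustrated with a toy sketch (plain Python, not nnsight’s actual implementation): every supported input format is converted to one canonical form, token ids, before being appended to the batched input.

```python
# Toy illustration of input normalization for batching (not nnsight's code).
# Whatever form an input arrives in (string prompt, pre-processed dict, or
# id list), it is first converted to a single canonical format (token ids)
# so that all inputs can be batched together.

def to_token_ids(inp, vocab):
    """Normalize any supported input format to a list of token ids."""
    if isinstance(inp, str):                       # raw string prompt
        return [vocab[tok] for tok in inp.split()]
    if isinstance(inp, dict):                      # pre-processed dictionary
        return list(inp["input_ids"])
    return list(inp)                               # already token ids

vocab = {"hello": 0, "world": 1}
batched_input = []
for raw in ["hello world", {"input_ids": [1, 0]}, [0, 0]]:
    batched_input.append(to_token_ids(raw, vocab))

# Every batch index now holds the same format and is easy to batch.
print(batched_input)  # [[0, 1], [1, 0], [0, 0]]
```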
The Invoker class is what actually accepts inputs to the model/graph, and it updates its parent Tracer object with the appropriate information about each input. On entering the invoker context with some input, the invoker leverages the model to pre-process and prepare the input. Using the prepared input, it updates its Tracer object with a batched version of the input, the size of the batched input, and the current generation index. It also runs a ‘meta’ version of the input through the model’s meta_model, which updates the sizes/dtypes of all of the modules’ Envoys’ inputs and outputs based on the characteristics of the input.
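The Invoker/Tracer handshake can be sketched with toy stand-in classes (illustrative only; the real classes are nnsight.contexts.Invoker.Invoker and nnsight.contexts.Tracer.Tracer, and the preparation step stands in for the model’s ._prepare_inputs):

```python
# Toy sketch of how an Invoker updates its parent Tracer on context entry.
class ToyTracer:
    def __init__(self):
        self.batched_input = []
        self.batch_size = 0

class ToyInvoker:
    def __init__(self, tracer, *inputs):
        self.tracer = tracer
        self.inputs = inputs

    def __enter__(self):
        # Stand-in for the model's input preparation step.
        prepared = [list(i) for i in self.inputs]
        # Update the parent Tracer with the batched input and its size.
        self.tracer.batched_input.extend(prepared)
        self.tracer.batch_size = len(prepared)
        return self

    def __exit__(self, *exc):
        return False

tracer = ToyTracer()
with ToyInvoker(tracer, [1, 2], [3, 4]):
    pass  # interventions would be defined here

print(tracer.batch_size)     # 2
print(tracer.batched_input)  # [[1, 2], [3, 4]]
```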
nnsight comes with an extension of Tracer, RemoteTracer, which enables both local and remote execution.
- nnsight.contexts.check_for_dependencies(data: Any) Tuple[Any, bool] [source]#
Checks to see if there are any Proxies in data. If so, converts them to a Bridge Node, then a Lock Node, in order to later retrieve the value of the Bridge Node at execution time.
- Parameters:
data (Any) – Data to check for Proxies.
- Returns:
Data with Proxies replaced with Bridge/Lock Nodes. bool: Whether there were any Proxies in data.
- Return type:
Tuple[Any, bool]
- nnsight.contexts.resolve_dependencies(data: Any) Any [source]#
Turn any dependencies (Locked Bridge Node) within data into their value. Executes the Bridge Node.
- Parameters:
data (Any) – Data to find and resolve dependencies within.
- Returns:
Data with dependencies converted to their value.
- Return type:
Any
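The check/resolve pair can be sketched with a toy version (illustrative only, with hypothetical ToyProxy/Lock stand-ins for nnsight’s Proxy and Bridge/Lock Nodes): proxies found in data are swapped for lock placeholders, which are later resolved back to concrete values at execution.

```python
# Toy sketch of check_for_dependencies / resolve_dependencies (not nnsight's
# actual implementation). ToyProxy and Lock are hypothetical stand-ins.
class ToyProxy:
    def __init__(self, value):
        self.value = value

class Lock:
    def __init__(self, proxy):
        self.proxy = proxy

def check_for_dependencies(data):
    """Replace any ToyProxy in data with a Lock placeholder."""
    found = False
    out = []
    for item in data:
        if isinstance(item, ToyProxy):
            out.append(Lock(item))
            found = True
        else:
            out.append(item)
    return out, found

def resolve_dependencies(data):
    """Turn each Lock placeholder back into its concrete value."""
    return [item.proxy.value if isinstance(item, Lock) else item
            for item in data]

data, had_proxies = check_for_dependencies([1, ToyProxy(42), 3])
print(had_proxies)                   # True
print(resolve_dependencies(data))    # [1, 42, 3]
```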
- class nnsight.contexts.Tracer.Tracer(backend: Backend, model: NNsight, validate: bool = False, graph: Graph = None, bridge: Bridge = None, return_context: bool = False, **kwargs)[source]#
The Tracer class creates a nnsight.tracing.Graph.Graph around the ._model of a nnsight.models.NNsightModel.NNsight which tracks and manages the operations performed on the inputs and outputs of said model.
- _model#
nnsight Model object that this context manager traces and executes.
- _graph#
Graph to which operations performed on the inputs and outputs of modules’ Envoys are added, and which is later executed.
- _args#
Positional arguments to be passed to function that executes the model.
- Type:
List[Any]
- _kwargs#
Keyword arguments to be passed to function that executes the model.
- Type:
Dict[str,Any]
- _invoker_inputs#
Inputs for each invocation of this Tracer.
- Type:
List[Any]
- edit_backend_execute() Graph [source]#
Should execute this object locally and return a result that can be handled by EditMixin objects.
- Returns:
Result containing data to return from an edit execution.
- Return type:
Any
- invoke(*inputs: Any, **kwargs) Invoker [source]#
Create an Invoker context for a given input.
- Raises:
Exception – If an Invoker context is already open.
- Returns:
Invoker.
- Return type:
Invoker
- local_backend_execute() Graph [source]#
Should execute this object locally and return a result that can be handled by RemoteMixin objects.
- Returns:
Result containing data to return from a remote execution.
- Return type:
Any
- next(increment: int = 1) None [source]#
Increments call_iter of all module Envoys. Useful when doing iterative/generative runs.
- Parameters:
increment (int) – How many call_iter to increment at once. Defaults to 1.
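The effect of next() can be sketched with a toy envoy class (illustrative only; real Envoys live in nnsight.envoy): each call bumps a per-envoy call counter by the given increment.

```python
# Toy sketch of next(): bump the call counter of every module envoy, as is
# done between steps of an iterative/generative run.
class ToyEnvoy:
    def __init__(self):
        self.call_iter = 0

envoys = [ToyEnvoy(), ToyEnvoy()]

def next_(envoys, increment=1):
    """Advance all envoys' call counters by `increment`."""
    for e in envoys:
        e.call_iter += increment

next_(envoys)      # advance one generation step
next_(envoys, 2)   # skip ahead two steps at once
print([e.call_iter for e in envoys])  # [3, 3]
```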
- remote_backend_get_model_key() str [source]#
Should return the model_key used to specify which model to run on the remote service.
- Returns:
Model key.
- Return type:
str
- remote_backend_handle_result_value(value: Dict[str, Any]) None [source]#
Should handle postprocessed result from remote_backend_postprocess_result on return from remote service.
- Parameters:
value (Any) – Result.
- remote_backend_postprocess_result(local_result: Graph) Dict[str, Any] [source]#
Should handle postprocessing the result from local_backend_execute.
For example moving tensors to cpu/detaching/etc.
- Parameters:
local_result (Any) – Local execution result.
- Returns:
Post processed local execution result.
- Return type:
Any
- class nnsight.contexts.Invoker.Invoker(tracer: Tracer, *inputs: Any, scan: bool = False, **kwargs)[source]#
An Invoker is meant to work in tandem with a nnsight.contexts.Tracer.Tracer to enter input and manage intervention tracing.
- tracer#
Tracer object to enter input and manage context.
- inputs#
Initially entered inputs, then post-processed inputs from model’s ._prepare_inputs(…) method.
- Type:
tuple[Any]
- scan#
Whether to execute the model using FakeTensor in order to update the potential sizes/dtypes of all modules’ Envoys’ inputs/outputs, as well as to validate that things work correctly. Scanning is not free computation-wise, so you may want to set this to False when running in a loop. When making interventions, you may get shape errors if scan is False, since operations are validated based on shapes; for looped calls where shapes are consistent, you may want scan=True for the first iteration only. Defaults to False.
- Type:
bool
- kwargs#
Keyword arguments passed to the model’s _prepare_inputs method.
- Type:
Dict[str,Any]
- scanning#
If currently scanning.
- Type:
bool
- class nnsight.contexts.GraphBasedContext.GlobalTracingContext[source]#
The Global Tracing Context handles adding tracing operations globally without reference to a given GraphBasedContext. There should only be one of these and that is GlobalTracingContext.GLOBAL_TRACING_CONTEXT. GlobalTracingContext.TORCH_HANDLER handles adding torch functions without reference to a given GraphBasedContext.
- static deregister() None [source]#
Deregister the current GraphBasedContext globally.
- static register(graph_based_context: GraphBasedContext) None [source]#
Register GraphBasedContext globally.
- Parameters:
graph_based_context (GraphBasedContext) – GraphBasedContext to register.
- static try_deregister(graph_based_context: GraphBasedContext) bool [source]#
Attempts to deregister a Graph globally. Does nothing if graph_based_context does not have the same Graph as the currently registered one.
- Parameters:
graph_based_context (GraphBasedContext) – GraphBasedContext to deregister.
- Returns:
True if deregistering was successful, False otherwise.
- Return type:
bool
- static try_register(graph_based_context: GraphBasedContext) bool [source]#
Attempts to register a Graph globally. Does nothing if one is already registered.
- Parameters:
graph_based_context (GraphBasedContext) – GraphBasedContext to register.
- Returns:
True if registering was successful, False otherwise.
- Return type:
bool
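The try_register/try_deregister semantics described above can be sketched with a toy global registry (illustrative only, not nnsight’s implementation): registration succeeds only when nothing is registered, and deregistration succeeds only for the currently registered context.

```python
# Toy sketch of global context registration (not nnsight's actual code).
class ToyGlobalContext:
    graph = None  # the single globally registered context, if any

    @staticmethod
    def try_register(ctx):
        if ToyGlobalContext.graph is not None:
            return False            # another context is already registered
        ToyGlobalContext.graph = ctx
        return True

    @staticmethod
    def try_deregister(ctx):
        if ToyGlobalContext.graph is not ctx:
            return False            # ctx is not the registered context
        ToyGlobalContext.graph = None
        return True

a, b = object(), object()
print(ToyGlobalContext.try_register(a))    # True
print(ToyGlobalContext.try_register(b))    # False -- a is already registered
print(ToyGlobalContext.try_deregister(b))  # False -- b was never registered
print(ToyGlobalContext.try_deregister(a))  # True
```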
- class nnsight.contexts.GraphBasedContext.GraphBasedContext(backend: Backend, graph: Graph = None, bridge: Bridge = None, **kwargs)[source]#
- apply(target: Callable, *args, validate: bool = None, **kwargs) InterventionProxy [source]#
Helper method to directly add a function to the intervention graph.
- Parameters:
target (Callable) – Function to apply.
validate (bool) – Whether to try running this operation in FakeMode to test it out and scan it.
- Returns:
Proxy of applying that function.
- Return type:
InterventionProxy
- bool(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable bool.
- bridge_backend_handle(bridge: Bridge) None [source]#
Should add self to the current Bridge in some capacity.
- Parameters:
bridge (Bridge) – Current Bridge.
- bytearray(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable bytearray.
- bytes(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable bytes.
- complex(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable complex number.
- cond(condition: InterventionProxy | Any) Conditional [source]#
Entrypoint to the Conditional context. Takes a condition argument which acts as the dependency of the Conditional node in the intervention graph. The condition is evaluated as a boolean and, if True, all the interventions defined within the body of the conditional context are executed.
- Parameters:
condition (Union[InterventionProxy, Any]) – Dependency of the Conditional Node.
- Returns:
Conditional context object.
- Return type:
Conditional
Example:
Ex 1: The .save() on the model output will only be executed if the condition passed to tracer.cond() evaluates to True.
Ex 2: The condition is itself an InterventionProxy, which in turn creates an InterventionProxy.
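The deferred-execution behavior of a conditional context can be sketched in plain Python (a toy stand-in, not the real Conditional class): operations recorded inside the body run only when the condition evaluates to True.

```python
# Toy sketch of a conditional intervention context (illustrative only):
# operations recorded inside the body execute on exit only if the
# condition evaluated to True.
class ToyConditional:
    def __init__(self, condition):
        self.condition = bool(condition)
        self.ops = []

    def add(self, fn):
        self.ops.append(fn)  # record an intervention, do not run it yet

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        if self.condition:
            for fn in self.ops:
                fn()
        return False

saved = []
with ToyConditional(5 > 3) as cond:
    cond.add(lambda: saved.append("ran"))

print(saved)  # ['ran'] -- the body executed because the condition was True
```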
- dict(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable dictionary.
- exit() InterventionProxy [source]#
Exits the execution of a sequential intervention graph.
- Returns:
Proxy of the EarlyStopProtocol node.
- Return type:
InterventionProxy
- float(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable float.
- int(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable int.
- list(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable list.
- log(*data: Any) None [source]#
Adds a node via .apply to print the value of a Node.
- Parameters:
data (Any) – Data to print.
- set(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable set.
- str(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable string.
- tuple(*args, **kwargs) InterventionProxy [source]#
NNsight helper method to create a traceable tuple.
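The helper methods above all route through apply(), whose mechanics can be sketched with a toy graph (illustrative only, not nnsight’s Graph): instead of running a function immediately, the call is recorded as a node and the whole graph is executed later.

```python
# Toy sketch of apply(): record a function call as a graph node rather than
# running it immediately; execute the whole graph later.
class ToyGraph:
    def __init__(self):
        self.nodes = []

    def apply(self, target, *args, **kwargs):
        """Record a deferred call; return its node index (a proxy stand-in)."""
        idx = len(self.nodes)
        self.nodes.append((target, args, kwargs))
        return idx

    def execute(self):
        """Run every recorded node in order and collect the results."""
        return [t(*a, **k) for t, a, k in self.nodes]

g = ToyGraph()
g.apply(int, "42")     # a "traceable int"
g.apply(list, "abc")   # a "traceable list"
print(g.execute())     # [42, ['a', 'b', 'c']]
```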
- class nnsight.contexts.Conditional.Conditional(graph: Graph, condition: 'InterventionProxy' | Any)[source]#
A context defined by a boolean condition, upon which the execution of all nodes defined from within is contingent.
- _condition#
Condition.
- Type:
Union[InterventionProxy, Any]
- class nnsight.contexts.Conditional.ConditionalManager[source]#
A Graph attachment that manages the Conditional contexts defined within an Intervention Graph.
- _conditional_dict#
Mapping of ConditionalProtocol node name to Conditional context.
- Type:
Dict[str, Node]
- _conditioned_nodes_dict#
Mapping of ConditionalProtocol node name to all the Nodes conditioned by it.
- Type:
Dict[str, Set[Node]]
- _conditional_stack#
Stack of visited Conditional contexts’ ConditionalProtocol nodes.
- Type:
Dict
- add_conditioned_node(node: Node) None [source]#
Adds a Node to the set of nodes conditioned by the current Conditional context.
- Parameters:
node (Node) – A node conditioned by the latest Conditional context.
- get(key: str) Conditional [source]#
Returns a ConditionalProtocol node.
- Parameters:
key (str) – ConditionalProtocol node name.
- Returns:
ConditionalProtocol node.
- Return type:
- is_node_conditioned(node: Node) bool [source]#
Returns True if the Node argument is conditioned by the current Conditional context.
- Parameters:
node (Node) – Node.
- Returns:
Whether the Node is conditioned.
- Return type:
bool
- peek() 'Node' | None [source]#
Gets the current Conditional context’s ConditionalProtocol node.
- Returns:
Latest ConditionalProtocol node if the ConditionalManager stack is non-empty.
- Return type:
Optional[Node]
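The stack behavior behind peek() can be sketched with a toy manager (illustrative only; push/pop names here are hypothetical): entering a conditional pushes its node, peek() returns the innermost one, and exiting pops it.

```python
# Toy sketch of the ConditionalManager stack (not nnsight's actual code).
class ToyConditionalManager:
    def __init__(self):
        self._stack = []

    def push(self, node):
        self._stack.append(node)   # entering a Conditional context

    def pop(self):
        self._stack.pop()          # exiting the innermost context

    def peek(self):
        """Return the innermost context's node, or None if the stack is empty."""
        return self._stack[-1] if self._stack else None

mgr = ToyConditionalManager()
print(mgr.peek())   # None -- no conditional context open yet
mgr.push("cond_0")
mgr.push("cond_1")  # a nested conditional
print(mgr.peek())   # cond_1
mgr.pop()
print(mgr.peek())   # cond_0
```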
- class nnsight.contexts.session.Session.Session(backend: Backend, model: NNsight, *args, bridge: Bridge = None, **kwargs)[source]#
A Session is a root Collection that handles adding new Graphs and new Collections while in the session.
- bridge#
Bridge object which stores all Graphs added during the session and handles interaction between them.
- Type:
Bridge
- graph#
Root Graph where operations and values meant for access by all subsequent Graphs should be stored and referenced.
- Type:
Graph
- backend#
Backend for this context object.
- Type:
Backend
- iter(iterable: Iterable, **kwargs) Iterator [source]#
Creates an Iterator context to iteratively execute an intervention graph, with an update item at each iteration.
- Parameters:
iterable (Iterable) – Data to iterate over.
return_context (bool) – If True, returns the Iterator context. Defaults to False.
- Returns:
Iterator context.
- Return type:
Iterator
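The iteration pattern can be sketched with a toy version (illustrative only, not the real Iterator class): the same recorded intervention graph runs once per item, with the current item as the update value.

```python
# Toy sketch of an Iterator context: execute the same recorded intervention
# graph once per item of the iterable.
class ToyIterator:
    def __init__(self, data):
        self.data = data
        self.ops = []

    def add(self, fn):
        self.ops.append(fn)  # record an op that takes the current item

    def run(self):
        results = []
        for item in self.data:          # one graph execution per item
            for fn in self.ops:
                results.append(fn(item))
        return results

it = ToyIterator([1, 2, 3])
it.add(lambda x: x * 10)
print(it.run())  # [10, 20, 30]
```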
- local_backend_execute() Dict[int, Graph] [source]#
Should execute this object locally and return a result that can be handled by RemoteMixin objects.
- Returns:
Result containing data to return from a remote execution.
- Return type:
Any
- remote_backend_get_model_key() str [source]#
Should return the model_key used to specify which model to run on the remote service.
- Returns:
Model key.
- Return type:
str
- remote_backend_handle_result_value(value: Dict[int, Dict[str, Any]])[source]#
Should handle postprocessed result from remote_backend_postprocess_result on return from remote service.
- Parameters:
value (Any) – Result.
- remote_backend_postprocess_result(local_result: Dict[int, Graph])[source]#
Should handle postprocessing the result from local_backend_execute.
For example moving tensors to cpu/detaching/etc.
- Parameters:
local_result (Any) – Local execution result.
- Returns:
Post processed local execution result.
- Return type:
Any
- class nnsight.contexts.session.Iterator.Iterator(data: Iterable, *args, return_context: bool = False, **kwargs)[source]#
Intervention loop context for iterative execution of an intervention graph.
- data#
Data to iterate over.
- Type:
Iterable
- return_context#
If True, returns the Iterator object upon entering the Iterator context.
- Type:
bool