
    9iIE                        d Z ddlZddlZddlmZ ddlmZmZmZ ddl	Z	ddl
mZ ddlmZ ddlmZ ddlmZ d	d
lmZmZ d	dlmZmZmZmZmZmZ d	dlmZ ddlmZ ddl m!Z! ddl"m#Z# ddl$m%Z% ddl&m'Z' ddl(m)Z) erddl*m+Z+  G d de,          Z- G d de,          Z. ee/d          Z0de	j1        de2fdZ3 G d de)          Z4dS )a
  
This module implements variable tracking for PyTorch optimizers during Dynamo tracing.

The OptimizerVariable class provides specialized handling for optimizer instances by:
- Optimizing the tracing of expensive optimizer initialization
- Managing optimizer state and parameter group tracking
- Handling tensor sources and guards for optimizer state tensors
- Supporting CUDA graph execution through static tensor address management
- Providing special handling for parameter gradients and optimizer state tensors

Key features include:
- Efficient initialization tracing via _init_group optimization
- Automatic marking of optimizer state tensors as static for CUDA graphs
- Proper source tracking for parameter groups, gradients, and state tensors
- Guard installation for optimizer state structure
- Support for both CPU and GPU tensor handling
- Cleanup of static tensor references via finalizers

The module integrates with Dynamo's broader tracing system while providing
optimizer-specific optimizations and safety guarantees.
    N)Iterable)AnyOptionalTYPE_CHECKING)TensorVariable)Source)getArtifactLogger)tree_map_only   )GuardBuilderinstall_guard)
AttrSourceConstDictKeySourceDictGetItemSourceGetItemSourceGlobalWeakRefSource
GradSource)GLOBAL_KEY_PREFIX   )VariableTracker)ConstantVariable)ConstDictVariable)ListVariable)GetAttrVariable)UserDefinedObjectVariable)InstructionTranslatorc                       e Zd ZdS )ArgMappingExceptionN__name__
__module____qualname__     T/var/www/icac/venv/lib/python3.11/site-packages/torch/_dynamo/variables/optimizer.pyr   r   8           Dr$   r   c                       e Zd ZdS )GuardInstallExceptionNr   r#   r$   r%   r(   r(   <   r&   r$   r(   
perf_hintsxreturnc                     ddl m} | j        re || j        j        d          }t
          j        j                            |           d u}|r%|j	        J |p|j	        
                    |           S |S dS )Nr   )get_managerFT)torch._inductor.cudagraph_treesr-   is_cudadeviceindextorch_dynamoutilsget_static_address_typecurrent_node_is_cuda_graph_recorded_tensor)r*   r-   manageris_static_addresss       r%   _is_static_for_cudagraphsr:   C   s    ;;;;;;y +ahne44!M/GGJJRVV 	%'333! J'FFqII
 %$ tr$   c                       e Zd Zdddhej        Z	 	 	 ddej        j        dee	e
ef                  deee                  dee	ej        ef                  de
ddf fdZd	d
dedee         de	eef         ddf
 fdZd	d
dedef fdZddZddZde
de
deee
         e	ee
f         f         fdZddZddZd	d
dej        defdZd	d
dee         de
dee
         de
ddfdZddZ xZ S )OptimizerVariablegrad_to_sourcetensor_to_sourcestatic_tensor_namesNvaluekwargsr+   c                      t                      j        |fi | || _        |pi | _        |pi | _        |pt                      | _        d S N)super__init__r@   r=   r>   setr?   )selfr@   r=   r?   r>   rA   	__class__s         r%   rE   zOptimizerVariable.__init__^   s\     	))&))),1
,2 0 6B#6#?#%%   r$   txr   nameargsr   c                    |dk    r,t          | j        d          s$t                                          ||||          S 	 |                     |           |                                   | j        |i |\  }} | j        j        |i |}|                     |           | 	                    |||||           dt          | j                   }|                    || j                   |                     |           t          j        |          S # t          t           f$ r
}	Y d}	~	nd}	~	ww xY wt                                          ||||          S )zVThis is an optimization to avoid tracing the very slow initialization of the optimizer_init_group__optimizer_N)hasattrr@   rD   call_methodgraph_break_if_pending_mutationmove_step_if_cpuget_python_argsrM   map_sources_and_install_guardsupdate_list_argsidstore_global_weakref_by_idcreate_finalizerr   creater   r(   )rG   rI   rJ   rK   rA   py_args	py_kwargsret_valmangled_name_rH   s             r%   rP   zOptimizerVariable.call_methodm   sy    =  4:}55 Cww**2tT6BBB44R888%%'''%9T%94%J6%J%J"0$*0'GYGG33B777%%b$KKK  ?bnn>>--lDJGGG%%b))) (.w777')>?    ww""2tT6:::s   CD D2-D2c                 <   |dv r.| j         sJ t          | |t          | j         |                    S |dk    rBddlm} | j        j        D ]}|d         D ]} ||d           |                     |           t                      	                    ||          S )	NrM   )sourceparam_groupsr   mark_static_addressparamsTguard)
r`   r   r   
decoratorsrc   r@   ra   _set_capturablerD   var_getattr)rG   rI   rJ   rc   groupprH   s         r%   ri   zOptimizerVariable.var_getattr   s     M"";"4jd6S6STTTT>!!8888880 7 7x 7 7A''666667   $$$ww""2t,,,r$   c           	         | j         j        D ]s}|d         D ]h}|j        j        }|j                            t          |          d           }|r0|                    |          rddlm	}  |dd| d| dg            itd S )	Nrd   r   )unimplementedz(optimizer: pending mutation on parameterz
variable: z, parameter: zSPending mutations on a parameter (e.g. due to using closure) require a graph break.)gb_typecontextexplanationhints)
r@   ra   outputside_effectsid_to_variablegetrV   has_pending_mutationexcrm   )rG   rI   grk   rs   variablerm   s          r%   rQ   z1OptimizerVariable.graph_break_if_pending_mutation   s    
 ( 	 	Ax[  !y5'6::2a55$GG  A A( K K 333333!M J GX G GA G G$y 	   	 	r$   c                     ddl m} dt          t          t          f         dt
          f fd} j        j        D ]} ||          rd|d<    j        ot           j        d          }|
                    t          j        | j        j        |                    }|j        D ]D}t          j        t!          j        d                    }t!          j        d          |j        |<   Ed S )	Nr   LazyVariableTrackerrj   r+   c                     d}d}|                      dg           D ]$}||j        p|j        z  }||j        j        vz  }%d| v o|o|S )NTrd   
capturable)ru   r/   is_xpur@   state)rj   all_uninitializedall_gpurk   rG   s       r%   safe_to_set_capturablezAOptimizerVariable._set_capturable.<locals>.safe_to_set_capturable   sl     $GYYx,, ? ?1900!Qdj.>%>>!!5(J->J7Jr$   Tr~   ra   ) r|   dictstrr   boolr@   ra   r`   r   realize_allr   builditemsr   _HashableTrackerr   rY   )	rG   rI   r|   r   rj   r`   param_groups_vtparam_group_vtkeys	   `        r%   rh   z!OptimizerVariable._set_capturable   s(   ))))))	K$sCx. 	KT 	K 	K 	K 	K 	K 	K Z, 	+ 	+E%%e,, +&*l#HDK!H!H-99!"dj&=vFF
 
 .3 	F 	FN#4 '55 C )9(?(E(EN %%		F 	Fr$   c                      dt           dt           f fdfd|D             }fd|                                D             }||fS )z9Get python values equivalent to the variable tracker argsargr+   c                    t          | t                    r(|                                 r|                                 S t          | t                    r	| j        sg S t          | t                    rjt          | j        t                    rPt          | j        j	        t                    r1| j        j	        j        dk    rj        j        | j        j                 S t          )Nra   )
isinstancer   is_python_constantas_python_constantr   r   r   r`   r   baser   memberr@   ra   r1   r   )r   rG   s    r%   map_argz2OptimizerVariable.get_python_args.<locals>.map_arg   s    #// 
AC4J4J4L4L 
A--///C.. Asy A	3 122Asz=99A sz
;;A JO*n<<z.sz/?@@%%r$   c                 &    g | ]} |          S r#   r#   ).0r   r   s     r%   
<listcomp>z5OptimizerVariable.get_python_args.<locals>.<listcomp>   s!    111SGGCLL111r$   c                 .    i | ]\  }}| |          S r#   r#   )r   kvr   s      r%   
<dictcomp>z5OptimizerVariable.get_python_args.<locals>.<dictcomp>   s'    ???1a???r$   )r   r   )rG   rK   rA   new_args
new_kwargsr   s   `    @r%   rS   z!OptimizerVariable.get_python_args   sw    
	& 	& 	& 	& 	& 	& 	& 	& 2111D111???????
##r$   c                     | j         j                                        D ]9\  }}d|v r0|d         j        r#|d                             |j                  |d<   :d S )Nstep)r@   r   r   is_cputor0   )rG   rk   r   s      r%   rR   z"OptimizerVariable.move_step_if_cpu   sb    
(..00 	; 	;HAu5=#7 %f 0 0 : :f	; 	;r$   c                 0   ddl m ddlm} i | _        i | _        dt          dd ffd}t          t          j	        || j
        j                   | j        ot          | j        d          }|                    t          j        || j
        j        |                    }| j        ot          | j        d	          }t          j        || j
        j        |          }|                                 |J |j        j                            |           t-          | j
        j        |j                  D ]\  }}	t1          |d
                   dk    r|d
         D ]}
|
j        d }t5          | j
        j                                                  D ]\  }}||
u r|} n|rW|                    t          j        || j
        j        |
         t9          |t;          ||                                          n|	                    |t?          j         d
                    }d}g }t-          |d
         |!                    |                    D ]\  }}|j        }|| j        |<   tE          |d          }|j        ;|| j        |j        <   tG          |j                  sd}|$                    |           htK          |&                    tN          j(                             |sKtR          *                    tV          j,                  r'd |D             }tR          -                    d|           t5          | j
        j        .                                          D ]\  }}t9          |t;          ||                    }|j        j                            |           t5          |.                                          D ]W\  }}t_          |t          j	                  r8|| j        vr/|| j        vr&t9          |t;          ||                    | j        |<   Xd S )Nr   rb   r   r{   r*   r+   c                 "     | d           d S )NTre   r#   )r*   rc   s    r%   mark_staticzEOptimizerVariable.map_sources_and_install_guards.<locals>.mark_static  s    ......r$   ra   r   rd   r   TgradFc                     g | ]	}|j         
S r#   )rJ   )r   srcs     r%   r   zDOptimizerVariable.map_sources_and_install_guards.<locals>.<listcomp>F  s    (N(N(Nc(N(N(Nr$   )zGrad tensors %s will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.)0rg   rc   lazyr|   r=   r>   r   r
   r2   Tensorr@   r   r`   r   r   r   r   ra   realizerr   guard_on_key_orderaddzipr   lenr   	enumeratekeysr   r   getitem_constr   rY   unpack_var_sequencer   r:   appendr   
make_guardr   CONSTANT_MATCHperf_hint_logisEnabledForloggingDEBUGwarningvaluesr   )rG   rI   r|   r   params_groups_sourcer   state_sourcestate_vtrj   group_vtparam	key_indexir   	params_vt
all_staticnon_static_gradsrk   p_vtparam_sourcegrad_sourcenon_static_grad_namesidxr@   p_state_source	inner_idxr   rc   s                              @r%   rT   z0OptimizerVariable.map_sources_and_install_guards   su   444444------  "	/3 	/4 	/ 	/ 	/ 	/ 	/ 	/ 	elK1ABBB  ${Vz$+~/V/V-99!"dj&=?STT
 
 {Gz$+w'G'G"(TZ-=|LL 	'''
	$((666  #4:#:O<QRR 6	 6	OE8 5?##a''"8_ " "Ez-$(	$-dj.>.C.C.E.E$F$F & &DAq Ezz,-	 %  * % "/;; / 5$&$(J$4U$;$5(4(:<(S(S%& %&!" !"	 	 	 "E ..r3C3J83T3TUUIJ!uX	0M0Mb0Q0QRR W W4#{+7%a((  
 6%2=D'/4QV<< =%*
(//<<<!+"8"89T"U"UVVVV  	-"<"<W]"K"K 	(N(N=M(N(N(N%%%
 *   $DJ$4$;$;$=$=>> 	 	JC.0sCC N I(,,^<<< )%,,.. 9 9  	1q%,//!444!666/@&(:>9(U(U0 0D)!,	 	r$   tensor_valuec                    ddl m} || j        v rR ||d           | j        |         }| j                            |j                            |j                             n|| j        v r| j        |         }nn ||d           |	                    t          |          }t          |          }| j                            |j                            |j                             t          j        |||          S )z%Wrap state tensor in a TensorVariabler   rb   Tre   )rg   rc   r>   r?   r   rr   module_key_namerJ   r=   rW   r   r   r   r   )rG   rI   r   rc   r`   global_names         r%   wrap_tensorzOptimizerVariable.wrap_tensora  s    	544444 4000D9999*<8F$(()B)B6;)O)OPPPPT000(6FF  D9999778I<XXK(55F$(()B)B6;)O)OPPP$Rv>>>r$   rZ   r[   c           	      ,   t          ||          D ]\  }}t          |t                    rt          |t                    s
J d            t	          |          D ]\  }}	|j        j                            |           t          |	t          j	                  r/|j
                            |                     ||	                     m|j        ot          |j        |          }
|j
                            t          j        ||	|
                     dS )z7Update the args and kwargs to the traced optimizer callz-py_arg should be a list in optimizer variableN)r   r   r   listr   rr   rs   mutationr2   r   r   r   r   r`   r   r   r   )rG   rI   rK   rA   rZ   r[   r   py_argr   valr`   s              r%   rU   z"OptimizerVariable.update_list_args}  s"    tW-- 	Q 	QKC#|,, 
Q!&$//  C   (// Q QFAsI*33C888!#u|44 Q	(()9)9"c)B)BCCCC!$!Lcj!0L0L	(()>r3)O)OPPPP	Q 	Qr$   c                     | j         | j        |j        j        dt          j        j        dd ffd}|j                            |           d S )Ngmr+   c                 D     d fd}t          j        |           d S )Nr+   c                      D ]x} j                             | d            j                            | d            j        rj                                         j        rj                                         yd S rC   )_bufferspop_parametersparams_flatclearparams_flat_unwrap_subclasses)rJ   r   names_to_deletetcs    r%   clear_static_tensor_refsz\OptimizerVariable.create_finalizer.<locals>.init_finalizer.<locals>.clear_static_tensor_refs  s    + A ADKOOD$///N&&tT222~ /,,...7 A8>>@@@A Ar$   r+   N)weakreffinalize)r   r   r   r   r@   s   ` r%   init_finalizerz:OptimizerVariable.create_finalizer.<locals>.init_finalizer  sP    A A A A A A A A U$<=====r$   )r?   r@   rr   tracing_contextr2   fxGraphModuleadd_graph_finalizer)rG   rI   r   r   r   r@   s      @@@r%   rX   z"OptimizerVariable.create_finalizer  sw    2
Y&
	>ux3 
	> 
	> 
	> 
	> 
	> 
	> 
	> 
	> 
	> 		%%n55555r$   )NNN)rI   r   r+   Nr   )!r    r!   r"   r   _nonvar_fieldsr2   optim	Optimizerr   r   r   r   rF   r   r   r   rE   r   r   rP   ri   rQ   rh   tuplerS   rR   rT   r   r   r   rU   rX   __classcell__)rH   s   @r%   r<   r<   V   s        
#	1	N ;?26AE@ @{$@ !c:o!67@ &c#h/	@
 #4f(<#=>@ @ 
@ @ @ @ @ @";#"; "; ?#	";
 S/)*"; 
"; "; "; "; "; ";H-5 -S -_ - - - - - -&   &F F F F@$$$'$	tCy$sCx.(	)$ $ $ $<; ; ; ;
e e e eN?)?9>?	? ? ? ?8Q#Q 'Q 	Q
 #Q Q 
Q Q Q Q,6 6 6 6 6 6 6 6r$   r<   )5__doc__r   r   collections.abcr   typingr   r   r   r2   torch._dynamo.variables.tensorr   torch._guardsr   torch._loggingr	   torch.utils._pytreer
   guardsr   r   r`   r   r   r   r   r   r   r4   r   r   r   constantr   dictsr   listsr   miscr   user_definedr   torch._dynamo.symbolic_convertr   	Exceptionr   r(   r    r   r   r   r:   r<   r#   r$   r%   <module>r     st   ,   $ $ $ $ $ $ / / / / / / / / / /  9 9 9 9 9 9             , , , , , , - - - - - - 0 0 0 0 0 0 0 0                & % % % % % ! ! ! ! ! ! & & & & & & $ $ $ $ $ $       ! ! ! ! ! ! 3 3 3 3 3 3  EDDDDDD	 	 	 	 	) 	 	 		 	 	 	 	I 	 	 	 "!(L99 $    &N6 N6 N6 N6 N61 N6 N6 N6 N6 N6r$   