From a628898793bc65d424607ac37ea4c5a692b46c22 Mon Sep 17 00:00:00 2001 From: Lothnic Date: Mon, 27 Apr 2026 17:48:48 +0530 Subject: [PATCH 1/6] feat: implement 4-bit NF4 quantization via bitsandbytes - Add get_linear_layer() factory in base.py that returns nn.Linear or bnb.nn.Linear4bit based on quantize flag (lazy import) - Add quantize: bool field to LlamaConfig - Swap all nn.Linear in Attention (q/k/v/o_proj) and MLP (gate/up/down_proj) to use the factory - Implement dual-path weight loading in weight_loader.py: - Standard path: batch load_state_dict(assign=True) for non-quantized - Quantized path: shard-by-shard per-parameter Params4bit loading - Preserve offline mode, FlashAttention, streaming, and stateless Sampler - Add --quantize/-q CLI flag to main.py - Add bitsandbytes>=0.49.2 to pyproject.toml - Fix .gitignore to catch __pycache__ at all directory depths - Remove tracked __pycache__ files --- .gitignore | 2 +- __pycache__/main.cpython-312.pyc | Bin 5498 -> 0 bytes engine/__pycache__/__init__.cpython-312.pyc | Bin 154 -> 0 bytes engine/__pycache__/generator.cpython-312.pyc | Bin 2287 -> 0 bytes engine/__pycache__/sampler.cpython-312.pyc | Bin 3354 -> 0 bytes main.py | 31 ++-- models/__pycache__/__init__.cpython-312.pyc | Bin 154 -> 0 bytes models/__pycache__/attention.cpython-312.pyc | Bin 6107 -> 0 bytes models/__pycache__/base.cpython-312.pyc | Bin 1281 -> 0 bytes models/__pycache__/llama.cpython-312.pyc | Bin 9560 -> 0 bytes models/__pycache__/qwen3.cpython-312.pyc | Bin 7111 -> 0 bytes .../__pycache__/weight_loader.cpython-312.pyc | Bin 6941 -> 0 bytes models/attention.py | 9 +- models/base.py | 13 ++ models/llama.py | 9 +- models/weight_loader.py | 158 +++++++++++++----- pyproject.toml | 1 + .../conftest.cpython-312-pytest-9.0.3.pyc | Bin 633 -> 0 bytes .../test_main.cpython-312-pytest-9.0.3.pyc | Bin 26296 -> 0 bytes .../test_sampler.cpython-312-pytest-9.0.3.pyc | Bin 19828 -> 0 bytes uv.lock | 18 ++ 21 files changed, 175 insertions(+), 66 deletions(-) delete mode 100644 __pycache__/main.cpython-312.pyc delete mode 100644 engine/__pycache__/__init__.cpython-312.pyc delete mode 100644 engine/__pycache__/generator.cpython-312.pyc delete mode 100644 engine/__pycache__/sampler.cpython-312.pyc delete mode 100644 models/__pycache__/__init__.cpython-312.pyc delete mode 100644 models/__pycache__/attention.cpython-312.pyc delete mode 100644 models/__pycache__/base.cpython-312.pyc delete mode 100644 models/__pycache__/llama.cpython-312.pyc delete mode 100644 models/__pycache__/qwen3.cpython-312.pyc delete mode 100644 models/__pycache__/weight_loader.cpython-312.pyc delete mode 100644 tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc delete mode 100644 tests/__pycache__/test_main.cpython-312-pytest-9.0.3.pyc delete mode 100644 tests/__pycache__/test_sampler.cpython-312-pytest-9.0.3.pyc diff --git a/.gitignore b/.gitignore index 36b9821..dc1108d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ future_plans/ -__pycache__/ +**/__pycache__/ *.pyc \ No newline at end of file diff --git a/__pycache__/main.cpython-312.pyc b/__pycache__/main.cpython-312.pyc deleted file mode 100644 index 0d038944cd7c10690545eda6b750ad5dba7c93a5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5498 zcma)AS!^4}8J;DV_bnbGscRj}w&a+U9j8g-ICdQ!mK;k;Ej6iGrwf|15|`$M*`;j} zRKWr%svrg`7X_>!Ey6rSsrBH8KBz!|Bq-1pJpj^@D|({<>InoW3gvhx^3wjZ)RKx4 zBOZV|GygF&|IGjY^9}##a99zPpZ?%KV$h4w-)ToH>`vp;GK0`9gi#D(2J6yHj8SV{ zOsCfRm>z00BK1GbESc6Y)e+0FL8KLC#I*t_WGNC=g0eWs);|CWnPI5G-eftU%&& zj&e;?@k~Cg3Od!Mn4Nt-PU0D9y6F1$H6c4dKgT0Sqfb1mu!#bWEBZvfP_!j-*%@&* znv!x^g*m1;M@RV4?>x&-j*Ofin;2Gn(-+Q6OpQ!l7$3ecrBM^R_s#?zJ^VghRGm%U z*>r?bJIAAbuyx8{;8h+W`oFjnS_h>-y?=6}G zPNDwbMA3fr{P}S)D{_O$xD2jh!-A9`VqO+=*`hTPQT-_*;-W5+Df&hWv$NnkBk_d5 zsT^)>s9&$J@?2g}OjwwS7t*rACWUmq=!!&=A{HWYQp{cfB^W_gtP*&nz{{i{6uU=( z&1r1zaypl|B5`7t%N68&L56XS@dRB5@qi7a4L(J4B!Y#jVnTq?aIs@ZZF6#tBL&dQ za(&c!2f%mxi;hS{7BYDn1O+0@&Zy;Ir;9!<0=Ql5q6wsP`AB~D&wrxj-=~Y=i9#YR z6eLcf!2<@sQi&iXb7@?(!^Xs~g9b#PBqH#_V)uCbx|k_sxU6su7DO3zU9*ia6#A`- zIZkHtaUuzd^9-3SWFY1t2ql3i_Bh5oa8;V3mt>-t$0vt|&+}tLieq$aXqcZK9h-P= zY~rk992$OMY;agH%s|M?iXJkCVpWHwBS>(BMzu#I8f)C=(Sc+xBMhW-axyC>28IOb z3K%-@Jjtbmge(nQO{X(({sWn~n2qM=NH^VU`ssw9^e%kLXw%`}aP-z3y?46T9p5UC z)LH8W+g@Yam+Te%EAKnW>Ug5cwy(3e$*BhWmmgoX|0%PLP=XoM0sAwA3`4r6#FDi# zv?9GTQXQL7b!S@~*NU(@RXzKXO8(`M-U>L=z_5{o~ z7A(;`$#H~h@pvs{$YCJT%TqcApE9bOJc}>XOda)5&l{bubyiYqp<{O*sfE6=5qi27 zdU`GN%!0G->Z+Q%wp{)N2iXT&3HUWjGh3g4OnM8ovUSrSBsa~4?b*p@O?_wpwV9n; z^c6HCukgT-rg*>#>gB3qRM*GV^RnpK`TE z(pn`J8?bT3bbDw0^9Gq_6}7;cX>!?9fa=>rg6@i0vt-^Vw`=_7{+MLy6KW|jC4H$i zpLrI|TjUOni!EQhf3KroH19wL?UJ`@)_#F%Yf&x8g5<8|_;!tx>d|Pey18C!$tp27 zx=Ys7{uWnjX;Rp!!WU)^eT|$;m2>W~)UBSUj=4M1~_W~fO|{HxYDlqBiQr`*yILD8r6FWxm0pWv67?YE;>NU zQF2K>94t8?O}SgUfI~O{J|5P_1mJC)ci|2y2d8`tdd`v~71=X5ba4mn)chKE(%qZ) zlsusGTx;}_XYc4G`&aE|$Ta)&CNM%6FF{Y zz9@;pf{VHn9H9rRATY}b5Q_&i1y3M~rc_)5PnM$TWaAnw+NXD*B}~bIf?mjCii-+T zlF%^!h+f}{B`!%|R6Hy92OH>yJ`)tPieGsVlOxb2^mLV);5P<+160qVvKRI~j zJm0`#&!3qdRlMq|9USU)4PVDz8ECqauK{Z4% zHr{HAMe`Q|FF9BxY@6yQguZMQlL}^GRz1uHdmuH0mkX{9p0I<{MfAwT}=pVj)mEN}^Z$?(xcWuA6t^0<4I9?uJc;Yi7@^o*w zBQ$>B)^2nys|EgF%yJ__<4lg}bnO(C+%Y$1k zPkHnotfgx2txUi5(wi?;Cac!R*V%7u0?is(XQR+q=zDo`%j(?Lv6i;YQ1{Zc*Sj}D zky8yfUubEn5_s)*jCOJ>d;cwC0H}Uw`jFl}}VX(KQcVFmAcq77TTd zcfnW>b}cyS;hqK8mOpsw(#=bigR8!Gdw$(h{q~ECm;U1a?!s{0<=t@YyX)Hbc4W&P zSUPaiwP^gSo`hw*VSmkDxp*%)0O#fjEDDvj<;U;2 zkJjD(TlSmwrSIHx@86D~z(wY>qsZdCVXB8a-#GNzp~`{RqZ{F9EgW47A6;5Z zAqKvRndhnB@_aw1n0N?{1kWqFoJ9N}Ph)~4l8P(Ycr?$zqnS?@E)$AC6+5JutTdA& z8G%TR?Bi1dS&Cj0#Mz{*{_7(Uq9!Z1kcD@*5Zy@;Inqck&K>ejP^*bY3G609KIS2k zA#FC2R2b$a^nNE4e=BxT;^SB2Vmf{q{vM*>6zov{WWtJ7l*BB}n+bu?+eBgWaXAUE ze34}1un57c>HBY-PTz)*S*18CvKkD@E)#t z53hTVyb`!`pltlujBNWV!R2SxtWT9q_XnO@J@nr3^?|dCkKKCm=99Nh-#lG;a^3sr z@29KC{a3xWzW-qP*&m<#;kl}-r^a&i{v%rVK#e{4iP6mTYzGWXVCncaf~s=-Q(ArE eyudKeG22NUV(kmXyZX?5*8HOph%Pp)+WCKLhAznf diff --git a/engine/__pycache__/__init__.cpython-312.pyc b/engine/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 61e8cf14f3e35270183d3d951188efdd22ad7ea8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 154 zcmX@j%ge<81pK>SW`gL)AOanHW&w&!XQ*V*Wb|9fP{ah}eFmxdWvQQ$pPQUzU@To0*rXpPHARnU|^`9}nb|#K-FuRNmsS$<0qG a%}KQ@Vg;JS2*kx8#z$sGM#ds$APWF{>m{xL diff --git a/engine/__pycache__/generator.cpython-312.pyc b/engine/__pycache__/generator.cpython-312.pyc deleted file mode 100644 index bee5af9f51b828132f4261b114de10440c08b8ee..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2287 zcmZ`4TWl0n^v=xAzSwr#rL^4w%~}ZSR@eP4De@>R%4FHO-OhG)X1Oyh zv~Hm>C8-HU__6Q>Mibk##_CtY7bbpS{IJU!w_GJ@q95{OfPVOhiJm*NQ)(MeGUuLq z&OMKN?m1_^t*P-NfRp3DDSx{V`k8GUfV~AP2Vj{(GD;ztlX;a(aj*(1pBGXB$9OK) zl^0VYhj`SDWN`w??pfYePPs?;m}m6hQANwDc#nb44Ph;VBU!9rk~S4xi>SI@h)4Pq zRgIj*89k39<^YZunE@;hshC96w24XlA+wi^any4(oi6|%$Xbg8%e`rm&KnuqrHy^n z&58`f6v7CiMk!wQ$~+ciZq$QNKQ9XtV#>uMEan1hm2G;q%94Gu3wvY{du2EFRdwC6 zXF_yNr2HdX%ujvY_N{c1q#nZ#GvlO@8DmKi;9KbeG=+*>5lKMVP~>69t?q;ioZ;t? zDLNd-8JuD^EEqvETr%dOVqTZAN`0n2h&5#t6T*%VxFiCp%fMxzFV+y+bSQd8@_rsqD1ejJ;0y&qVC6NtJCM9&vYTANQ_ zkTiT=vaVtP-+tS*Y4n1^4)|8iLPFm_97`q2t)`CIv!Mu;I>$S;n5)*LBV}uDHN8){ zV9Wh<=TJH3bsgla?)fiswi?HnOGaa&S>s5Hyd&vEZoVWZhrybRisGC&TP+mF0_fMSa|PR6{YGBVBZCdwre#%N zR*OziZgw=BqI6;%^O{>6S+3Eco~}h#(VK-x=T;}g7;tjT_&RVbL(Jv=%4g0wX{wKf2nW0dhv()U6Tiwf~||eT}$op z&)a9(=TF@|Hr-xsf3s8)~V~k`9O^9xdLM^TDEgu){hDMr7{If1;Zp%P&AXUS|Bk%Gl&Ek&7O)d=N4jw zaRkr=ddv&r?)m6K=NIu61mq(>a<_~huZZE27%q!jAMh~WKVqi0#>|xSIb|(Vc76)iR$SC@W!ON4AhwG_5PCiHmeP^; zL+>4xilI@?Pi3AWLqfv_0qQ(OeQMeisAvkHaaGZv?TcRuq6$zTDA3N{ zku=59?gBSAvo|w4H}mc6em*o5L@|X+=jASG;Dw=oSjOIys z7^JcDUWU$iR8HgPc?NmVRV1_bknG*`xc2kD6(+$wp~75(u@Bvlw4AEwi*HFp(#({r zGQwz{kb763pBKlwfVIiPbcrNhUrZCRL8d<6oVhUNSU1ts2&nuB4`~U^8nO zx#_ovk-;g;oL*K{P0^KUtS>4$o;GO#F-hiD?1+@hsVibnWQEao+{yp>6sBtq0lAAF zhhr6?9zMC_Ke-PR$Y!z!eg&ZD{>txxDI<%aN_1q}pmHUb(jV=BZ2`Wova8-AD>I{1 zpX=3hdD*ks_P5IQc+9Dm%hfxL;lx^?#NTHg$BtC4u8YlBvK~urkVfp)QgAOk{NB53?^eE1yCcB6XTP`e{U0>K;*MV=hiC*$?m(q!fDUM7sLB7Yw9IUEYrV=^-Rfor zRgaa~t!^7~kl}@jKeZWDOYead_#$8QWUm0Gd_|wk(H8gtt<=-LyOusWI!Wmw9y0+F&PS?zESj9(aw|{==kfP5il5> zbFlY-u~}KsZbC~A%GgX{U52Kdd_uRzETELookl1e5<#1pwSA_MwlwKSgnEZoVi3#b zQ+dtIYc`is4ISIT0w#v(;?!rRB+G*KC1{eF#j==IRF#ZScm5vi9;42<95HU2c8K}} zO;Fk-JC*w>XBr-#BOO6&%rq{NuM0@5Ss}4?9>38ozk5uaMN{^7f;dmWjb-?R@T_=08 zfHo)A0lUiaI>73H*S<)o2(gZV*Aa`JInKz?4%9b_jLgd3Ev~b#ZXm0>t1=Y+eS_+6 z@!hrA9`D0_svyESBPHDyO{IVn@xR{v zSzMJ&D{dKCtjAr|xS^{nUGuk{TY&l0{#+vN9?x>yN|DX3wzH)d-5^`va=O6Hgongo z#`a|~&OtvQ$3Y?|fIwpq^t>jz{>+5nTp9>XCOxmcX)q^GFzf92;%MPS<0nH)dB2$A<@)@CZ>kAx$R&SX6{Sxs-G zKU5#6jo8Jf8YdqK$DXOtUMccOII$NUDbK9WRQZj|552q5H+MpB{(U%BzP5g?lHNG6 zJA9_}9Z+dTC+g9O>aB;&ZuG)V=z_DG=ipxF7e!stun6b0C~AhBS1BJ7#XEUPb-%RJ zcoXZ;1D&4fW}Q`Z%gMQ$Y`U$xbVA$eLB) zA9y{y_36SV;0M^&%TexP<<>72e!kE`z}7A~uWkHW^*8F5bjA&pMhRzu1DU9E|4Hh4 zGM(2`mSX6VYI}6u4nCC~JlQY0mrxfFaA1jUe str: def main(): args = parse_args() + + # Resolve local cache path for the actual requested model + # If cached, enable offline mode to prevent any network calls + _cached = try_to_load_from_cache(args.model_id, "config.json") + local_model_path = os.path.dirname(_cached) if isinstance(_cached, str) else None + + if local_model_path: + os.environ["HF_HUB_OFFLINE"] = "1" + os.environ["TRANSFORMERS_OFFLINE"] = "1" + else: + # Ensure offline mode is NOT set if the model isn't cached + os.environ.pop("HF_HUB_OFFLINE", None) + os.environ.pop("TRANSFORMERS_OFFLINE", None) + + model, config = load_hf_model(args.model_id, device=args.device, quantize=args.quantize) - model, config = load_hf_model(args.model_id, device=args.device) - - # Use local model path if available (from user's caching logic) - tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH or args.model_id) + # Use local model path if available (prevents network calls for tokenizer) + tokenizer = AutoTokenizer.from_pretrained(local_model_path or args.model_id) chat = [{"role": "user", "content": "Write a short story about a robot."}] prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) diff --git a/models/__pycache__/__init__.cpython-312.pyc b/models/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index f64a187e19990bd5ec03a4b8ec99f0aa8f3b901e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 154 zcmX@j%ge<81j~26%mmSoK?FMZ%mNgd&QQsq$>_I|p@<2{`wUX^%ThlhKQ~oBC%+^k zFEd%+CAB!aB)>pEpeR2pHMyi%zbq#wH#09&KQ})mHK$lVJ|4&^iI3MSsJz8t0~9RH ZNwq6t1)9VN#Kj=SM`lJw#v*1Q3jow4CQ<+Z diff --git a/models/__pycache__/attention.cpython-312.pyc b/models/__pycache__/attention.cpython-312.pyc deleted file mode 100644 index 7c0880f8fce5380ff5eacbf9da75b002bca99cd4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6107 zcmd5=YfK#16`q-$eX%bde&84zh;fpQ4H$49rj89Zu@%RrCQeGTiaj0g%)&1FfIG9; zkRfp-q%2h<;70{kBYFKRSVjt#D#cc-2CJ1q)k_uX20f-q&paeQj zg$ga5U_uOCplt~z40M7G*^51nkV9ZZD&{wqvRc7bx^K{83HEU+bk8!g*-Z^3pm}P$L#fv(802M$kMIR7c~bF6G5A15BH~&1DahK>gxX4-rS!`h|Teo+@Fkw zo4Z75Ku!%dpC+l87?!2x;dnd|O-7p&DM5@&&Acp&NjaKIHVuww9+HxIS>*cpc!ZS0 zf)4mg%}`yUR(<}--ih9+p4lUr-kh)Yyvj+K8^k*MLAZD$FSK$)++D>Uxn@Y&bFsd#M*=`-UVd;eKEmNq$q3mBU1-@-3si7QuW!mWar{G1w)1EH zxhl1Kuj*-BpcjtbaQ^buCHfcLOJ~%Br`6`3WlzsC+e0c~Rv7+fDVhMawimkIhm(p_ zh!&_2lc6HCK##jZHh}@l1uGV8<8;WLp)S%P2b{pJxjT)Q5k^%R#Z=vX==~`@L&?Sn zHDObz4`BRyv4;kgm9#aAGgkx-m$`sRC=`eL=ml&XU;Kswy*dOu*gXmdhPfs`tX|?W z1BcU?2OMZx8wzlkQpr`uV5ut&a+I9^5SdFzgXRTNY%zAR+IiZ529U>V^nXivN@xpY z|Mv3u{?FtwPx4s(nATyO0Bz=>nkfB#AOF7acI!o9;O&_bBbawl;NJTZa9Kf3@bxe9iM}&GQ*{#kFm9NBxy^m(R_~`Q{E4 zemgp_`-#=+omT>v1HkWF-Cm#H-l%SG#DVHBwp7nVr=wXmU$;-K+qcm0Nz;v{#r>bX zbo(W>_4IP@*?ezY?TxQ&Nx)psosHc2)}w0c(M932#O;LIdS*HFqkJf(hEgkA2EP;k zNB*n+g-xIAxUoa6JG98=J5H+|-;ieYiG~dOL7)Ke>+Kg&3bAyp1CS4}*mPhq1Nvdm zDCodqCaxO2g?Pb;7K~Ol5wbKsmqK5|8ID3oHs(;kr3`dY6nS@@R3?g9yRREH@31wh zu!_AjMq7R-R%!7Y)I58==5h9FB%X%`V6zgCZc8ONjD)sU#mKxZo(;7)?e9T1${2ZfMAU zq$@Elr!+@`mlNr@=9USclm=6hsIk!15127TWU>o+SeZoA8ke3YxFVADWxKEg1@La^{%#7je&?%-`+(UVI_jDn{2{0 z>0!-2#KlFxJitk8S-?D~aTf`%5yoU#3UQe~oFbxy!=u#}WIZ^7?^OB~RDfH7O_R!m zGTS*jbmi^KZ_l;N@wwW4xj@U$yFPj4#w$04PZPHix&0^Az{w1|>M5UU`SF1}Rn;@? z)9oLMS7MiAbDI|&H-|pGcb}~& z(2{F-F5hriZ8*G0=NgXu=H;cQetjxe^;)Lui$MA0$i&F(rflo{8y}y&c6KH3y?o%1 z8aM82UChS-Xt>bXRMMWFiP;`(e84$vXB_PLuhM4K3tpzc8 zGB2-FaXn%`)2UmtzvF%`#r;3s&kg6Zgf}U{hKQyxml%OfRwz*8^TRnYI~ps4vzwu^ zxfG?iEr`)mQO}rHHX&F64J#hSYrGMi!#U4HFF#iwV*j&9Nld<7r#FZiOM zdk}))8?1y#xSxkm_!;B~V)ZOk8ao^nFA&s??kva)NTDCp<4iqa);$eDLA44Z4vM@C zv9v6bxX2HS1id491qaI@K$W9?=~P-Gd!Qp2(4BB8lYVTCqLFpJa528k%acMM#_mP*SK!i?@)tg7H;oRJi{ zz!PD#!g4RguH7(4`YSkL_zY0JJ72v=t===?%h)n+Om(h$%O_u*cy-p7^FEcKAN6$C z>zh0_acrt>);=p=dFS#wbMnVy*T!<*m-61@s`vO3z0{KPp29`eG3IOcskIm(H{NK> zc{}pn6RP*bQp*yb^Pc{%SnpO;&TO6DI$JrtJ=>e*uF^N3UaY_MY`*=3+J0iGRc-H9 zE54s`fODHUJbgIZl8xu8+9q6Z$}@KY+tHC`cV(qqV9x}bp))OaJ(bIzop;JN&$y@E zvoB;@F1KfgW}D}RR?1s4oe%53yJhQa#ncZnWpKv4w^sGmVh+%%*0tWGm%X()Z@Ye` zhYj8JSLXfos=q#4pY55W=bl@5D(63tv90oH~(tciFX5_rQ=Gb_bo9 z4kr182vI${kZ9LuKPDoD97VA~?m%PiFwYo{0ish|0X`rEbw?#G%%>$j4l!p3 ztg6*WVLlGQqL7j?Llx3t*$Uo*jt$XZ;nw11)~~LF(iW~}K~)GBqdOivQqzNB-w|Mu zFqy@$uRI@UR0EB3{6fV-=i=_A`rE;kz^P@|DP5o4L6-awILVt>g|Na6u2BVTBRG(= zM%4*~9*&MvYYc7grVC6s9`qGrYM1`mwJDh-h!I+q@z@keC&P$Wd|WHDXk9_%MPzln zE+(ZE(Z5XU(5X32>BgnM@`4aNFj|6Xu7d^hJ4E_m01ZKU9jY&FG);d_dH+Oh{xem( zOx1o)dA?*B`kATcW)4mt{E`AP`|?L8ubx~(@+%ijzeeA8w9vIv3|`**6p;7(X@>5m Mr{4TK@Dxw_A1Elgg#Z8m diff --git a/models/__pycache__/base.cpython-312.pyc b/models/__pycache__/base.cpython-312.pyc deleted file mode 100644 index 9c4f2cfd2ec3ca4199db30d1278319da0b5bc04c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1281 zcmah|&yO256n^t7nS@PurMBWXGUd`pt3;~Q%W8$x76dEN2*I94NR~5x*-Y0nPHoRZ zRy`oJhjM~Fh2FTL{wZ8QoEQmFsV6w41TLJgXR~RyEW*+}fA7t+zt7Ll-*-A~u;an| zzh^)B0FRw>IvPu`*tK8=9JnVS_YyDn6W?>bFM=ef=P-|w$O9kV10Ea!57EQu&f4i# zPv?i>2~L_*zZ?DLoUUI5H*VglF0i3Cilrt;GnRaiK3xTHy4B(f8=V0{;&DiP?j-?+ zE#NT^&L^}|&K1HVY;ymw3BVf|pPz%xk^i#g8^yeN6g^z(Nm^EVtBP;2QZwE;$O4m)ejlp9svcFvN^AO|rtE8nG06WOs}Ni@mHsQlMQ%gp~U(LuVX&gQ}Gk zlhV+P>uPNxb!IYI)ZY4pX+!sMN)MPQQCAIxrc_0`Ea<9Ofkg`!oWa!V`s|tq&?d`q zL+$CUmFI#hzAbWf$V@ukBpUBhfkpRB-m_xZ>@hv(bc2wQC3fM4Ju5WAQ>as;uT+8njMdjPC zal1b9iO93Ut9bjQ9&fAH?Lyn_poul}!1KIc;jPE8_6uD8GrZ(=?(45_&UTJ>{sQaN G3jP5KBtcaG diff --git a/models/__pycache__/llama.cpython-312.pyc b/models/__pycache__/llama.cpython-312.pyc deleted file mode 100644 index f7662bec935db767a86322dccc9219571d7a8ed0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9560 zcmc&)eQX;?cHdp@QcIB{B}jjKixu6e~$$taJkGz z7fn%flt}p~krtUS?W0M`_!yE}BCL<4^*S!X1{j)(Sbb%LVhQsRo6kn_Y`8pP_t{CF z3p*kezKV#`=cFmg#!#8mpKsgW;yD|P zPLg7a=VVZc%l>eS=k$aWeZIBB^LH~p@c6@F&zK^4;v+$&^(&m+fS1;6qCf5rg#C&F z1^%diO!0^NPkI^6b|fB`qVZrXsyTb1e#9&X^p(-$+%i!c{|e0_-s-+)H!n{J1MzH zsnX|=EMnzU8GHjPRv}HQ2)Q9C3O*Om8;giggjQQfPu0FksR~Lpq_p)0<=Zk@>8;i5 zWWkTdqC>%9t$aKd@DB<~a7xnbBSBG=qI#h!Iu;Rfgs^{7k`=8QNq%F6g%Qaw=7?9N zNnzX{9+Ps!s$euO$q`8m`r{JmFG%qerEQD&uM49wC5T%jNRdHF6ob)W<)7aBmp@*7 z^ykYjY7RM~2+^1v5u{ONn9MHur-!w2IW{T@@ewKR@70|7Zx{^v;S=IPNtlSqVnL4s z8js&6h{1@)iSfx%Nn-hi?0S_9=#D))|IVwSg%&L55rU~e|BbldC z$jug(_52iYCMdzIn4~7D`}Bae{A2!iZ_4&*9`G;Syy^T%a? zS4yr%0y`Lsh23ZZR@J20n(VfQG~1A^cc)o*F4if#o=&mb=g*|sr?U->GyS)k(`;k5 zu4(4PEiTPAWw(1~25ucov!1M{ImK?9hrXJ#+nZ*NCj+;R&4!SrzVS}uLgUOU$&Q=d zv)$Oymb#hV#}?PT{Nnf^Zhls_d#Hje;C`t@@&KvT{w1YnIX2gvLfBf0Qh?d;r?qXY?)TguH^C zcsyjmwX&g@JmHtcDVMS6Il|M(c?cpvCs*YUI_A1>ch6tDGqy0cc=`{n-l7r5($|=*m($z@dVcF2_kw$!!u(S62QBwo9+A8b57^BnctFv_uyF~ZP4EEV zFvJjGU~V{ktZ0$AQs)W)qa+=GfdgPDBf!AJ&RR9s8-UXCPo0(`f*$5Ne^KAhHR@ zI4U&t-QNhKpga|Y1QSAKRutanj9#ZL1-+u0D+&oq{tKqr0$$X6MIpf^*da!=zHi

zOV{8W5*^+JRLCP7YvqBMA_V+_5lPg_m0&cN_d4Y^?0}q0Vj~b(w*tyXhb7%`!h@^1 zWN8>~3z94hjtva~{1XcUhtas$7|=CKzz>)N!+_ZkjC#v-AnL{5k0To$2U`Tb4(h=Q z)lWF0DJTU9g?gc`=!}^90cWE88AOlihxt93hJ9+ozMpbGbNs}ya{6QIr>;+2snc(y z+s|a$FR1Mo((M;hSH#Sfhzfu05w#(zR>zWjmUm=$x5~TceM_Kgo?mlR-Rz(3pRdi- zA5`lPu3Z1Wp{G}~BIxi%(AQ<8iF z%M947D0fg#tPNa2`7B1KFnS##1$8(PD_wgWhT_z=VyVh0Ky0s^g9sF$bIZ-xY-~QP zI`?Iq9jdcq*`Id40Q>r_6|OBe$7jcH`R9gj5C7ix7cKX!cdd)p?oHgC_>n8+YD@8L zWS{TbighBcTwWc@N3eJpqn9x{3Q?i{dj*LY?Zb*^AOiij&hT94{M%nskgV5OxOPMD zL3a7gpI!RNrAMF>+jT+1#{X|x&vLT=bs3>xlbXX_BEo$f_PfXm2KBMo3XUT;2_dVI z%crG9v9y{J>Gv5!$$|ZA_V#}^f@w28afA{Ix%AA6=3pK}xba#Bdhv#f*L1I~0NtO1 z)zYl}!Kmby^~*3`aWva7`U^(kwyTwmjp`Yj7@LTaEL@(+A;Z1XAsaVrGh91hb4`fu zfokO%M4KUct$zER`i1(%&djb3byvsoc;?_K_28-0TNhXBFM)CIT-#PZH+g$`jXlWNDw)Y%KG+kC*_DCnVE-POJFR_4$t^^n=ad(J$b$_HWmO_GK?KqT#5 z$PgGsgeeMfxN{p8d9%?oXOgZE7C#>L5lz^c)`;Ft)8$K>CD zlI+LmC5Q^`cmfeqc3klsd?8#YsJ_vDufF-t>^pzxzcaKjw0QR3mAhA#kEEMB(~SpL zc7N3JVavzY{-jN9>`S{&r1%s1dR@R#j$$9UA~|dOA`&s`#tJkL2->jRWg6iI({%^U zy;fjozgA#qzm~g)V-ErNHv?)j3>)~p#fJ8#Tet&m;cySXLHn%4crGKxbLCkuOB9%R zM3P?)#{yRgrl6+^ZEAEyZMtQf%JUl-t_WWqyJhFkCY(4E6Yy*GQ2HDiW*y*&iBECU za8A&9Q?u}<{l6Z|JnwOI>5!0@Fc`_i5#vs8F8~5mAvypM2T^%2Iyx374nP^_t+C+i zgPT5aZlPO{NP=dGgh!1FK0$*K?1tRaaGgQh7MgJN#XDsJ#H+xppkx#QM8l57J(-;c z)SU;GhccagYG+@%;RGCf*ILca#m=R!bWJCeY-{zNLX0@Bb{tRF_YnlCuDvz>``zFX z;E8Al3~7V0@(!1uLIK?X(<~g+5+IOY#%UTjOh~{^^njb1FmFkh!2X`5;a0#VXuRY@ zE4^=19E^5w={Ma7z=h+RilG#b(OZH;t4MdvuyP1k<*OKlF^WJ0Zh`HPXcdMp8Qg&2 zlqnnagLm@kII)d1UY3-g2yR2Id=#?av4<)J_ZxN&LV=-F_t@U$pk@HLiGvuRrrHCx zbDjH+Yr(Ztk*@1ZR%CbYz1MfQZz-~J?&ENJ_ZgMnk*v5im96tE9#ZStlNDbYZJZ0v z#p!fi2eg681NUCI`@+)6^sX+ot_w&m_k1Yj+MD9{l5GTYaKOu#w4Ppp@6k=D!NmNPiBfn$7f0$+*)* zQ)mv*t&%9jV^_gzM=l~V;=fjYlDtavgKI$Jbf1A88{rw!tokDs@!RQ-5iHRs_(Y?X zh2gCWo{RLChyy?>)Git=e7kNbK2oM`LA9d%Ekv7jOSaYx3c|kTZoJdK(7#lh@ph@+ zu9e^}mJDTumXwSd@TN%|n zNfYG$r01z)t^6iEhk#D7gaIr8J!hGz7pdU_r#Hw6)D#Ef{Gf3Po&shjoUn*AdKG5i zl|5Qxlyq7t^H?N4h)w=9(@4pWhW<1-8HJiTS1DR>G}L)Ljzx#bJD`l-5qScMHVGc2B^eA1&FUW= zm7=2jHdZDe^4fJf#HJgR;L6fX50`F;81L1J#uEaL9K$}hm* z9E>XpVhb@>^r+Bic?=YCM#~U1QL4bw1Ea;>kh1U0Htm`p`mwTf;Dhe_-RY+fs*MMe zr?U2KDSK1ay>ICVoZOVV^Y*osQ=c6Hc=+s^+H*S7b589!mpboLdoCuAuWk3-=~?Jm z;+OZXT=+=%Q26XL+(gyqPpzK2lsV^D&-qg!B@>!bLsP5U-+n|Vaokq`bAv;c@{W#hW>k_gG8 z+Dh&j-KIh|_#H%OCmqqAG zvztFhkrWw^nP%7D=UWH;iX&sDiET($h>4P z7R`rmjqiaT;SK&p`G-(}mj?y?3F{V`rhiVo_zTMZ3(D~Ywe^=&Q;KT(g4*!~<@qJm zxo)Kx`y*xx-AAue7;jjBuucdMpDCj+(8)pcBtN7u|FFhMx8E9^3*H6-WFLBX`ZYSa m^QL#!`;fx?!`dp^drL-l$wLaV5BIw1XKr2pnu3IU?SBJ_Hgd}V diff --git a/models/__pycache__/qwen3.cpython-312.pyc b/models/__pycache__/qwen3.cpython-312.pyc deleted file mode 100644 index 69214cd4f051b2f8894c4345ce8bbe677a234802..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7111 zcmcgxTWlLwdY%~$$ssu;br)^R7A0G@=!;}qwv#BzCXT)GC9)jHO|;88p3>wjC{XQ7n>?+6zHmyV1FZO z=opiv!!${mFhf$C#6}sKlI&pzZD=?pHq4T|Ey+o)uuF1>-I6Elk@zq#dBa}G7xs~Q zHtCm|!p%}kxK(Nkx6u?sougv*_o^c8kH9&+fOF|0U~PEE$Ablat|H}xhW zfyQSdb4nyRc1dsgJvow6CZU-qzn)A-Z|cssE?pgmvhIH?tw!X9A4@mHSS+5J3fXkm zX;l?dYCN6NeWS^UGF>b4k(rs~f*_MNQRB;un5`C`9)-*w!%~YBEHQ=4E81d=$WF33 zYLbrG-si*im_5e6Plp{j>JRC#^C|_O;l&TF4djR#ALKTM5=7}nw7QaHARG?MPyZUC zKciPDm54D~vSp|bVc#=!hIzyoVwzITlK}FB!&0^~AmyxrmHC7=d&p3N)x-ZnU)Bs3 zs1P$AqIH`%qlBEgU74E^Wt|g*cq*<6g6xF3=-jjzi3zc|r1P?*ko^}xUw6!syzaP3 z(ub_hDq?a{_CX`rk1bt-5QQyN1VOuV zIh_!rsxmmAOiD2Npp=e@No8;rPvF4Jg4_xnEOuUZ|*fp0I#xKsGxE(iPl5 zls1HGppj*>_N*gACF+k#RWCz7qK$0U2^!XB*1Di+>ffaFYC+re%r}U z^qK8z>7#1)drs)@d;=F@E~epy`A~@iwQ|{|W)+D>N}*{pYve zek)@Rp&WoZ)CN#K9hngkqh|nF5u=TC-bmT`xOhugIg4wgYNqJdgy}We+Ob)wCR))@{*vhU&h0vd#B^En3xh2HKK}&?wpAzDojP^VPbJ420}ch z=(ZcmqHdoRl444y1)ZMN>6$^L;GHKVA zZ$T;f4=u4dI(G~@BM$@s}>F3ES9i92+d++4j zWVy`Gy{!S!@54{iW|SU-8P@Iq}BpQrP; zpHP^t`M)OVH|_)WZZav-APTxxFv-miCChDiIT|^AjrDEMl>Y;U|G#6(ror06lX0a; zsORFKVNt6e#-!`UvN*eHk%t;dgOMyVus}Ve$3q^u1?FSyI8H+Mm`E326lG)=x;vhl znNta$CL5=-0GTP>1<+TO;_m4-DLEqtAZPJvVkIaLTTqkV9Wl^HmF%WSZX^7CD-PTV4Pu$O^SJ%+jENZ5jgZC{Uql z4Iw&o)(-TSCiID}_FXmf^^LdmjpJLVp{7;SF?-9{(dfz-K_|q@W01&~FdE0`GD;o9 zcFD(3a2%r{h;$AtPdo=}DTu)2v<7orCD8T2 zf8W3AEd@^Hyp?_ZA7A+B!m6}3@oBQO@2bZ2GKfBBvdr|>6Bo#SFO{qoj(q)pLL_n;eguH+qzcOKDXda zt@mwQrg{uXARTK6=A6$`cbOt{EeS-0hF~JrGV{Ss7TRSQvoDyP zg_7!Sv)T4v1X46P4Z-Xida}N$ZWVhQ)S>v?-nUs>a0wXmmcs}LpCmvZ`^|hUmG{H< zZG05CHG3l-Q9|u?Fpi_?ObQU%dc)G) z^XX{hhM*v*zk&5UydQ}%K~3KjQ;P1sL@2>nTv2s9@a-v4Zzeq4V7++)Ix4;Sm{%k zBjf^&2>{ZaI_nW=xM7yzK)_SV5N$CW8(gr`(M4WVy7oL6yFa$tQ4S4jq2aaB^2oS0 zGG6MsoMSh*9h)726=_-8xE3kim@eN~)NU+pT)SP`m(i}>E>8Spqwl4){h#jp<=|(7 zpZC30ymnjLmnn5*Nylxw@}uRh0j+D`v7>zSjCS-)sqHK=M4dZU7M2(4jnSxfa&-Oe zjn1o3&sRLnWl!JVdHO27VAqDU-l1~uu+}@AJ9~F*Y3y^p`-zR>_y60!I;p-aP0;zX zBfHObb6d<7f;IsZ2Cg)iE!&+PF=~odEyD#z$pR42tp@Xop@)~Xf&ZV(*kZ(6hL-?3 z)n=9qSps*WA;cKmvMgJ7owx=SD1y-gv$t(~^74QtOvP^?;dckHPgwIEVS69h{Jty$ z-wzMpjiF{lv-T+}>!^db?_IC6)VpxC&9@5%URc0qw`H8T7VxB6H6Enio^@p$iS1rx zT;}X(>Vtk3FS`YjwRvXjrWQ~KDuDwuggOm(#+`8x&)YsuTx~fB$}%Eih|%pabpgHY zm^dGgiXpFV1GiLu7c~$8<%1CEj)-yF*&83vtIZ*_p_$euifCQopelKXx z%!sL&d>z~W0YpF_af-UjporiD8yv%L_{ZXcFds?Ify2xZ0Zstfh~8qkzcm#)uGeMQetBc!T6rR(O+<glTd1m-Dy3rXcyIt;v4SpjfCD3g}Kjh^3?8YqC;Aoi+|#D|w|ENTmiQ z=TcF8JpnTM4(RLdnn0&Vbqo=86UA!U3Od6MBq3=jH8@_SH^n5&>oePW|& z`d3u%uPM*3DgFzpyGV8a1J(WoHC3jjzMv$HlAf_g=hUsI5h!Tt-Jh2PHr diff --git a/models/__pycache__/weight_loader.cpython-312.pyc b/models/__pycache__/weight_loader.cpython-312.pyc deleted file mode 100644 index d7ffca3dca5677a4e60ad0db05af4b023b66eaeb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6941 zcmbVReM}rjmam@gZ)OH&h8bXH0DoZ{2nKA&S?ry;5 zdJ=XOX(tz32HokH-DI__w9;}ie>l2bh5PHATPeGz`{QOjlQ8ZOU82iB`QuJx?Ug0p zNq4WhDH=Q8wUi!F?^V^SSFc|Es_IqoZ+5!{!MEJ=-|=605&BP@A)j%Bc_dH>Eg%Ll zR1%FUlp3Yvyk=AbRFl-EbfY>7*J+dblws76GL9PMye?@hBsM!ism5^aja)|B-UU8A)O zg@2?KjHs4z`$3pNa#}6hVwUZwA+0EMlafn3TT5= zqVX6#K6#a9GSg|St3@yAlUR*pOvQPDizFreb!fFuGRlPRR`Q*aE;`9XBvUHGut_@3 zNCt+Tibq*V&j>ds*r;}khn&FU{G+FUss)5`qkIe{qW1tNA_b{d)WEzLR38W%vw=`T zM_Q;fL2#%RAx1}9GkWaltX43SY7sF8QM+usyJ1;YCs;S@bc~72SvSh+1v{xpI0&aw zQBQiw8UWewBW)=pxQN_a2F8r%Tr@D24Qkz>Hif=$1a^>@>xoXtJN|-v6|;$2(eOfz zF2=3udZ9!$Q}YrD5957_gzu#iwZv=7{<}8COIf3e1sG=lkCUmxDUqeXe=oGTo@pQ! zWlf?G>gikFc+}#a^^qe&d*bw<+LZS%}8Wzof&`pf}m{K*0 zCM=w_2z47xP@tAAPA}FlI~f!}vTQ4nP${#E zY1y;^>wA&RZf4I*W%e?yTV$#e?P?!{6E}Bn(IwU#hT>as9eUILmNn5qWK=4;7%$Vd z+)h@$JZ?|cD;g8~R4Kx7+eK@li{wqK|ljr*65J4zTz*sa5F~%=M|Igk$QU)d0z`xz@jGbZ>0n)Qpdi8G{~N|{d4&v=)+i05H{-1JT21-3#@LuzZniT=bn z!mAXl?$!y&g*$YHMrDB~}2l2)Ji^8>dc^v~h}x2H z#BYBYxzGCL&-CzD(8Fs)uS&m54|CM5eL9T7%c>s25q+>P?19;A5E~Msq)w%PLFX3V zy`k0+PHfn^6NwG<=Gwj0D)${Bk}7?!@1Pj`j(rE8={xwXzBg@A`xYrtvwV=uGptb$ zw@vi)p*!xVW(-9&S0H-Gx|u_wdy0b5n?(1#=9X25vZ1V73~>(i^kNQ+p=YB?(fdsO z5z+f>{hKc6-AHf77cetX)iS~n@4(90jDr&lsa0tL_VedIwXJ9sjWNwIdOo6@)7`qP z#FXEv(cU^rYJ@P+ybW=G2bOG&skS|HDmk{qRF3Z$Q&l~G4yJ9-TbI@BW&TPK{8J5?#dsb}lfv5L;EW8I22q7`EFy*h~2=d`ze zH;ohc04^6mgDB|j4mu>!ZWt;bX*t@%aLm7@0@Ac8BS+}Fh4mlk>LcEX$ub%Ru)Tm9_Da z7z>`|GaTPSaaCAr{kLk`;%SDxAxdBk3pEJ}gQ*TSp_Tjo`jGRrks=ptMm4RG0uM+M0c;raMj8s8Q#WP}KJ zv!6`n#tyUo>esTHO!4#-9pl*RAZh_u@o-%P?6hb&eBG1B`1Sw62i01bD9IAd@N_g1 z9cLNI!pGAaJh%hPV=&;}@FdTe&&I)s^w>f0`$(F-0oO@z^1*ORm!zNI0DY1s!%Ids zJr(CNX-PZA3Y-xu(E*^+k`AvDNr$hq91d?J$Jz6J{X_JH{*$MMFJ8C|0+~1n0Iqq- zatR~5pW`x|q!%&(2ibHE-x4Kr|BWa+fiW#eMt|{Nr4U?Oc!{( zhUKN1j6_*UBV?p1`Fakuuud58u8cg%CIgukGS>hRl0lhi$(V!{38{p z@hsfMGt6X?g`VQz7toiEvNE1wy0}Psj8)LAk3oNeRK>AlcrrNp>SQd&a>^Qzy{ykn zLb;@kf|wR21cn8kuuOm!fGgexjB%dE$dxQ?YC^b4C*#*xNe82a7Xd{bFX>qD1BdU( zl2wtC=d~rw;nPgJhVux^xcz06B`emug#n*=PM3BVh5O`@hGeVU!7ugMP^7+sQO zauZvMoAE*eH*=lX97>S%7~bF`m@fbx_K{>zd;nV-3%e~mkzh@H7#f6M@;&WsAJ3%N z_GCsFPsgL}eJp=X$V{||H*tg_eGTpPSjI5k2So+Gk{;yos3xWB-LU2A<2)|Nt zb=(^+x_V|$uelqScCB0~Tp267$7j!cWwz(8rmbID?LbEU>Rmo3(8uLSA~4X+df9hLfEA=Fi@?*`@<13w)oH217D6^>je z9vm*!UsT1yO}{q$(oooW^!~O|_+-W7TeviTDSvLMujJ_fhKJ6&{6NXsoYP}31=mex z{h(&ew`1u>;rK{loGbhIoNdh$%!~I1%AP|x%X$rRIA;ecRsPkgP_ZiXX;tGp+8mW@ ztT@^wN1%d(N=H(0T`S8Yv2Thrpfr5pGA3$~`R?erQ9 z-ZpQ`Us`QARBSkOzvuV8AN3YSe^_=#a>v$c{0nF1&n!}lJ*Aqxx&AdUvt_;|KeE`L zKX*@4c6a7Zs9F!-tNzgUf$tyvAJ%_RzrvO}`U>B_B=_>#%Y{C*y|(T`)lKWQ$n9Ho zH5OfspSl1FFpGDNynAHLA6)gf6#Xqr+gA)_|EszF+o#v{$h~ddfU0Y;@BSWqSFyV8 z^uF7>76`2db{7M?m-eshCoV7M}xUb`R>_4K=qeK$Cu5!OTNBe z9>_DS z?f$~#_@X8-JMi||iq8C#(?2@>_7H3zm|4{Yin>5qR}Vv&+mSm`(A7RUkC%VZ9D4es z2M}H~H~i0I2eNtohzuC&o^4C77mog*aQ&~#9x-Qujbm@k*Vn@75}r9<9}0e#tJ2EH87Uu-s) zL&w~Z`J{`2R5H;Jhe5=ez-~d0y zcXBu)l&Z!j$Hw3kib2$Y-}2>9h{LZ)lIf}(5p?W_b6((-zzm-octrTtK-0>JeGqqo zZ(I?E?d8 torch.Tensor: x1, x2 = x.chunk(2, dim=-1) @@ -19,10 +20,10 @@ def __init__(self, config, rotary_emb): self.head_dim = config.head_dim self.hidden_size = config.hidden_size - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.q_proj = get_linear_layer(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, quantize=config.quantize) + self.k_proj = get_linear_layer(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias, quantize=config.quantize) + self.v_proj = get_linear_layer(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias, quantize=config.quantize) + self.o_proj = get_linear_layer(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, quantize=config.quantize) self.rotary_emb = rotary_emb def core_attention(self, q, k, v, q_len, kv_len): diff --git a/models/base.py b/models/base.py index 7e20ead..6481346 100644 --- a/models/base.py +++ b/models/base.py @@ -3,6 +3,19 @@ import torch import torch.nn as nn + +def get_linear_layer(in_features: int, out_features: int, bias: bool, quantize: bool = False): + """Factory that returns nn.Linear or bnb.nn.Linear4bit depending on quantize flag.""" + if quantize: + import bitsandbytes as bnb + return bnb.nn.Linear4bit( + in_features, out_features, bias=bias, + compute_dtype=torch.bfloat16, + quant_type="nf4", + ) + return nn.Linear(in_features, out_features, bias=bias) + + class CausalLM(ABC, nn.Module): """Every model must implement this interface. The engine never looks inside.""" diff --git a/models/llama.py b/models/llama.py index 84150f1..f8f1cf5 100644 --- a/models/llama.py +++ b/models/llama.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass -from models.base import CausalLM +from models.base import CausalLM, get_linear_layer from models.attention import Attention, FlashAttention @@ -24,6 +24,7 @@ class LlamaConfig: head_dim: int | None = None dtype: torch.dtype = torch.bfloat16 device: str = "cuda" + quantize: bool = False def __post_init__(self): if self.head_dim is None: @@ -59,9 +60,9 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor): class MLP(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() - self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) - self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.gate_proj = get_linear_layer(config.hidden_size, config.intermediate_size, bias=False, quantize=config.quantize) + self.up_proj = get_linear_layer(config.hidden_size, config.intermediate_size, bias=False, quantize=config.quantize) + self.down_proj = get_linear_layer(config.intermediate_size, config.hidden_size, bias=False, quantize=config.quantize) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) diff --git a/models/weight_loader.py b/models/weight_loader.py index 9bae224..51c68ba 100644 --- a/models/weight_loader.py +++ b/models/weight_loader.py @@ -1,6 +1,8 @@ import json import os +import gc import torch +import torch.nn as nn from safetensors.torch import load_file from models.llama import LlamaConfig, LlamaForCausalLM from models.qwen3 import QwenForCausalLM @@ -16,8 +18,95 @@ "qwen3": QwenForCausalLM, } -def load_hf_model(model_id:str, device:str = "cuda", dtype:torch.dtype = torch.bfloat16): - print(f"Loading model {model_id} to {device} with dtype {dtype}") + +def _remap_key(k: str) -> str: + """Map HuggingFace weight names -> our names.""" + if k.startswith("model."): + k = k[6:] + k = k.replace("self_attn.", "attn.") + k = k.replace("input_layernorm", "input_norm") + k = k.replace("post_attention_layernorm", "post_norm") + return k + + +def _resolve_parameter(model: nn.Module, key: str): + """Walk dot-separated key to find (parent_module, attr_name).""" + parts = key.split(".") + target = model + for part in parts[:-1]: + target = getattr(target, part) + return target, parts[-1] + + +def _find_shard_paths(model_id: str, local_only: bool) -> list[str]: + """Return list of safetensors shard paths (single file or multi-shard).""" + try: + path = hf_hub_download(repo_id=model_id, filename="model.safetensors", local_files_only=local_only) + return [path] + except Exception: + index_path = hf_hub_download(repo_id=model_id, filename="model.safetensors.index.json", local_files_only=local_only) + with open(index_path, "r") as f: + index = json.load(f) + shard_files = set(index["weight_map"].values()) + return [hf_hub_download(repo_id=model_id, filename=f, local_files_only=local_only) for f in shard_files] + + +def _load_standard(model, shard_paths, device, dtype): + """Fast path: load all weights at once via load_state_dict(assign=True).""" + state_dict = {} + for path in shard_paths: + state_dict.update(load_file(path, device=device)) + + mapped = {_remap_key(k): v.to(dtype) for k, v in state_dict.items()} + del state_dict + + missing, unexpected = model.load_state_dict(mapped, strict=False, assign=True) + del mapped + + return missing, unexpected + + +def _load_quantized(model, shard_paths, device, dtype): + """Quantized path: load shard-by-shard, quantize per-parameter via Params4bit.""" + from bitsandbytes.nn import Params4bit + + for path in shard_paths: + shard = load_file(path, device="cpu") # Always load to CPU first + + for k, v in shard.items(): + new_k = _remap_key(k) + + try: + target, attr_name = _resolve_parameter(model, new_k) + param = getattr(target, attr_name) + except AttributeError: + continue # Skip unmapped keys (e.g. keys we don't use) + + v_typed = v.to(dtype=dtype) + + if hasattr(param, "quant_type"): + # This is a Linear4bit parameter — quantize and place on device + new_param = Params4bit( + v_typed, + requires_grad=False, + quant_type=getattr(param, "quant_type", "nf4"), + ).to(device) + setattr(target, attr_name, new_param) + else: + # Normal parameter (embeddings, norms, lm_head, etc.) + target.register_parameter( + attr_name, + nn.Parameter(v_typed.to(device), requires_grad=False), + ) + + del shard + gc.collect() + if device != "cpu" and torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def load_hf_model(model_id: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, quantize: bool = False): + print(f"Loading model {model_id} to {device} with dtype {dtype} (quantize={quantize})") local_only = os.environ.get("HF_HUB_OFFLINE") == "1" config_path = hf_hub_download(repo_id=model_id, filename="config.json", local_files_only=local_only) @@ -38,6 +127,7 @@ def load_hf_model(model_id:str, device:str = "cuda", dtype:torch.dtype = torch.b attention_bias=hf.get("attention_bias", False), tie_word_embeddings=hf.get("tie_word_embeddings", False), head_dim=hf.get("head_dim"), + quantize=quantize, ) print(hf['architectures'][0]) @@ -49,45 +139,25 @@ def load_hf_model(model_id:str, device:str = "cuda", dtype:torch.dtype = torch.b model_class = MODEL_REGISTRY[model_type] - # Initialize model on meta device (no actual memory allocation) + # 1. Initialize model on meta device (no actual memory allocation) with torch.device("meta"): model = model_class(config) - # Load weights directly to target device/dtype to avoid CPU copies - try: - weights_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", local_files_only=local_only) - state_dict = load_file(weights_path, device=device) - except Exception: - index_path = hf_hub_download(repo_id=model_id, filename="model.safetensors.index.json", local_files_only=local_only) - with open(index_path, "r") as f: - index = json.load(f) - state_dict = {} - for shard in set(index["weight_map"].values()): - state_dict.update(load_file(hf_hub_download(repo_id=model_id, filename=shard, local_files_only=local_only), device=device)) - - # Map HF names -> our names - mapped = {} - for k,v in state_dict.items(): - new_k = k - if new_k.startswith("model."): - new_k = new_k[6:] - new_k = new_k.replace("self_attn.", "attn.") - new_k = new_k.replace("input_layernorm", "input_norm") - new_k = new_k.replace("post_attention_layernorm", "post_norm") - mapped[new_k] = v.to(dtype) - - del state_dict + # 2. Find shard files + shard_paths = _find_shard_paths(model_id, local_only) - # assign=True replaces meta tensors with real ones (no double allocation) - # Note: Tensors in 'mapped' are already on the target device and dtype. - missing, unexpected = model.load_state_dict(mapped, strict=False, assign=True) + # 3. Load weights — dual path strategy + if quantize: + _load_quantized(model, shard_paths, device, dtype) + else: + _load_standard(model, shard_paths, device, dtype) - # 1. Re-tie weights if they were tied in config. - # Using assign=True breaks existing tying because it replaces Parameter objects. + # 4. Re-tie weights if they were tied in config + # assign=True / per-param loading breaks existing tying because it replaces Parameter objects if config.tie_word_embeddings: model.lm_head.weight = model.embed_tokens.weight - # 2. Re-materialize RotaryEmbedding buffers on the target device. + # 5. Re-materialize RotaryEmbedding buffers on the target device. # These are computed buffers (not saved in checkpoints) that remain as # meta tensors after meta-device init + assign=True loading. from models.llama import RotaryEmbedding @@ -101,30 +171,30 @@ def load_hf_model(model_id:str, device:str = "cuda", dtype:torch.dtype = torch.b module.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) module.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - # 3. Ensure all remaining tensors (e.g. missing params, other buffers) are on the correct device/dtype - # We use to_empty() for any remaining meta tensors, then to() for the rest. + # 6. Ensure all remaining meta tensors are materialized on the correct device for param in model.parameters(): if param.is_meta: param.data = torch.empty_like(param, device=device) for buffer in model.buffers(): if buffer.is_meta: buffer.data = torch.empty_like(buffer, device=device) - + + # 7. Move model to target device/dtype + # Note: for quantized models, bnb parameters handle their own dtype, + # model.to() will skip them automatically model.to(device, dtype=dtype) - if missing: - # Filter out expected missing buffers (RoPE caches computed at runtime) - real_missing = [k for k in missing if "rotary_emb" not in k] + # 8. Check for missing parameters + missing_keys = [name for name, param in model.named_parameters() if param.is_meta] + if missing_keys: + real_missing = [k for k in missing_keys if "lm_head" not in k] if real_missing: print(f"Missing: {real_missing}") - if unexpected: - print(f"Unexpected: {unexpected}") - - del mapped config.device = device model.eval() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() return model, config diff --git a/pyproject.toml b/pyproject.toml index 834cd67..1498a57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "A fast and efficient LLM inference engine." readme = "README.md" requires-python = ">=3.12" dependencies = [ + "bitsandbytes>=0.49.2", "huggingface-hub>=1.11.0", "packaging>=26.1", "protobuf>=7.34.1", diff --git a/tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc b/tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc deleted file mode 100644 index e95fef759895effbf73b3c67f269dc04b6689ca9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 633 zcmZuvze^lJ6rR}~ch=($(j-RkI6?|J3%5y;KoCPLw(=y{WSix@9l6=e{$OUju#G4e z(9**~{{+MRDXESi>Oe>l2v`IZJ1b{A*R#kJ^S<}q_rCWH%#T{F0`&RQ_(umOnxs6= z1GxFC!4Zf+1Q`h6eY6t0V?mIhUt3;7;_C#}NjXMw&uz*c;Y0}juRA{X& zyO+j)^Om`hCa9i(DZWwM-M#9ich>Z1!A#jpPq0pyRYh4dXu8ktIH(h0KG+ zXMJ7S5W?R8d;|T;FYxl(0d95Q?{A;s`Uuy5;>H=CAL02EIKm4%=)!h(WdGg4$0L8> ikLU$_IaC7hf9 diff --git a/tests/__pycache__/test_main.cpython-312-pytest-9.0.3.pyc b/tests/__pycache__/test_main.cpython-312-pytest-9.0.3.pyc deleted file mode 100644 index 8d266024fc619a3623f1f5d2ad55b91c1b64616e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 26296 zcmeHQdvH|OdB6L(yLVq&Bmo9v%;n7jS|Rbq2HOaOAz0W3H;vL{v*=z)3;W8wD?n>0 zCb6fHCvHrdv{n*#qO^HH?aZWYn$D#CqevMtrjOZ05~3S-lD0|HbUGOaI;F#3{e9=` zeQ32S$u{w$qs!{M=kc9$&%O72-}zqm@OPn*U%+wlv0o*gTP6tqMj7L9mlL^z}&*U#Xt%g07xD zQNCcO(J}S!#h(j*w>IEW;l(%x967VT-|=cWJ>XQF1Fn8&zw7h{K^SoNyY>IcQ9E{= zGG4DyfV1N2cXv4yH?`?yR8m`m0zgh5H2A<-5#hv%yEs2Vx z`1-|!gXKKAu<>#CMx9qo$f zRJiMskGJ$cobAhjIzW1*`WO*xOjEO3XC?)TR7*ir$4>SplDR}$FL{y~VCbbKF})|# zn^YKTtUH-G8BfO2@swKf#q@X=ZaS-#>Kj8t39*bp4rATysw@1%x;~Yh0KpTyPJD5DYY4|rYD`~Y(Ajs zr*oNX^AlR;l-iloo6jVZsYE)_OgFDH7%MZg(mM7)7>CkhotbnFEH8Fa#U<5PXELMX z1;=vn?)`M z@T>`o(!Liz#dBj;(yYEF?HlSctBvG1>x>rz1J^Mdmx6D+xmzgT#NtfJOJvz}E2s8o^^V~j#qweG!b*FX@UCpDdBvj&O+A0ET z0PI|9hdr}P#~M{;S=xHSj4%$OskvS)t;aH2qC1fWb3NOmreo>M6!!B#4fZpa#NViE1S6NNA=-Fx2YIXkKO$gjuA>-wZqYfXZwSmt{8pQ--i6ClhH^ZmBG2^&3*0 z{~*zzYzT0nDOtv3v^vaBCQxz8C-Tx35m-!sID;kw*d-KB?`!N5L#h{Y7LCSKjAADt za-Gias6pyw8KqXcg>=c60f@>!Y^U;F7grRdT}0|Dis4<868$F11w3cMqO^-hoiN6% zq*;AU+C}tTt~Qe6tS?>+jEgHwd@})*d9ud45d>q6n^(n7y=_%&d6}$r2>sLbu#)f& z$#lz`OcFs zo`6j)dE4Y`SW>zYnw~S(-BdzRlBz-2Dl~D$s7l85T**!F9TN5iM%cZ&s>i4|yMWCp zo2{+cq?+GTgP`Zo_CAY{e>zp)hwKXhc&f-f@f#AkA(X#$@%kE!pEH_79Tg_g*KO0G zy0%XyQD#PcKWqE8yQrpXn?~VT@ZP%NQZx;h9>seX4VR87#!IPP+lfiZh$m7uJ8m<~ zj!0CP9b2cE9ij0uSvSI!V^WD*32f7v(P!RnZrsh4mtJG-%=GY6W>Yh9 zU0OfGB0Fz~druVKc{?23Ach@oiLpFGaZPgD8|2_?Fmgjo$W=cf~5U1dJ z@fM_}nZdY0K9f%N#WGoyTq-0~8x}CtByht(mVcxMjnBP{H8c*xB(}*eZC9+%?<12~ zd&FOoV$Agti7OgB=ofa&kHx5PVu3el2hD5U^juyHzgOceuuc>mIoEQ2Ws4jKl5;vQ@1{W zgx!a$x>Wa4ly$+e`n(%HT=L|4m?ziEJh>9{ zHyq*H1&2+-HGH$3PQ^w43@nFmTv#o%2)g4eT=*YB8;8*6I_LU`aMlrV=bIi=lUY^U zBQsyPtoP|r8Wl`RicCH?lpF5v4UuSGG7$3KvQMf#;?zWpN+W+U(#**GIMO=Zm@IQ+ zk|)Oc$uFGp1fsl@NKvOreqQ)lsNuq%Vc*a*7f%#dv`mCr-)LO^#RqS@gu1Phq1J1m z);l^~=L`O|u5XG_Z*{}#bvHe{e&G}(#LA=Fv}p|Bah|e2{4y$>>6Mb>cXD)x>^McQ zsNztZ%$Yu2a`YL9r>gi`j28xH`XXEZGs<-E%2*>I%U;D{`EOvcanFKlNq{r=X}PY& zS6pS+^;8Lv91-lWU6K3}6~}C@YrGG#3sxL>A5w)RKezYcs|apz=jPwXlxI7K_d!PD ziUaS%|F+&oVE%o)Gx|ViN3i`DUF`VxhnR*0AL3SuzJMCU12d4YG$c6vQ(*3Ks2r*l zn*lc>TYB_MF786yi#Mz_5ikXyU;M-0)A3&q;Q~!FT_BTr7?vj``c3dW&zP`?Qe2G?#;oKx;quaJ z)R&2FCV)&Ad-jEA>E_k|3_d%&#(r;&$3aEQ;*BsHtajMA+=98EO&yoQ=rVK#w6IS@zU6KL7;}ApP zZ)$?RpeCCVN?uN>xp-4D8BfKVj}mO!8f|TA+23?H4I683C%3bLf? zp9Qcu$SALwRZ7g@J;!zaZ}i92~&87~VfA(Ql$$z;h-nq8JoT7-Lp) zoRE3xHP)SpZzg~;Pjz?~r>aqsX?G=G=`7UrKBYQmQz{>N@~oO}Ca|?>+kWoC_&Dyl z8m(wCs;EBimrR}ah|1J*eB!U0Y2sUWd(gN3831eM9~f>fENdMLZ?#SQ=HW2h6oUi9 zVKXoYUJANQO_hg z)sm|-+iM8b(HW#yjy~s=#KZC5;X1N_(G%~A)6xA^=*U*T4{xP{evVue3(KNo;bu-h z5K9Zv>LC?i@R=dC7+yUo(Qkt1dB%iAl!AT;Va!U76D}{kMtzy+W&#-Viw54s|KCMJ z(Nx?YUn;sNRYk1Qx~Zc2QnA=RwXFppzozyCx3Ko3U;Wbn_GM!2(B83d`h$>Fa|mfjOIUC$tpinU*0? zZS}}fMfKHVIa@uZ&R%PI$Zx8$h}1$pf~)Ay1DF$j=)lm9vG6LklDvPW`(05wTyN2a zkB>?b9VJ%S@ICLbPsMZBHEpkH-O@fwt~h*6Gh9DH$DU4>#Qk=k+f3bWL)i%%DpZ5<15D-+ARXXbtv zjKTB7f){Qa{gyHKT_F}Ph$WM4%I4i$qPw?ubN|<`%5?Hu7}ZR4@)4?kXLJ(gbVBYW zP9e*yQHWVhlv|;AA09qqlbg6hX5|p1rj>BWoPZEDDba6&=Xu72MU+BjB7`w3IZn8| z^qPdeOms5=j2T3x`89-3XBe(Nple7FK9cx3kOuitORnuLtqJU{QP#shJPTQLymPWx z$jHKSw{qH8RfRSxaEN3p2&Dbuqjt+Kb0V%CUR{vZ!fCj=7+yOm(Qkt1dB%iAl!8DA zVa!U76D}{kCZR7A-An*uI)PAl4~xJeInW+~RV`E~=LQlxj1Q$KCJTW_6jxYrC=N@{ z$EGOkpB1CkH$6s6D930y+w*lzO}**U>CD-56ZwuJ?vjV^jZpGv@dUmQDEW`~=~$#6 zJ(tKa|5e1n0Lj;dfmH-V_ zr^@GrH=Qm~9@;kS`tt5u0+KiVZgKt1pi^vfn1O?9D8;}=N-?mWQmY88y9KbBUxqrT z5t;{uqy_MIi|)WaX~^4*U_5-GuXtYbzDkn|2Q%}~B&#_3&Db;q_L(tRij)xg9p7+# z-fIEWBWHEURTYRh>5Q;C*c1TLZ}Ixttd7q}@ko%NDwx zj$S)^a7Aof6p{b+{FmfJ7we6^Y9*uU>Ge6Z10BhIYHqW97PG(@>`ApJekPIW#o!WY z_9gL|L|oQ;PnNGQH_It}!pny@B@C73GQHGWG8Pp2#s+&q2`OO!0i@PfC9eX3jscf7 zSIsh8?4enPDB6CpTKr_pWTQE{GCQy}k?S?lV2~`-)vELKSLwgYSrZ>$tJJKM?5x{% zeP4}ozY>_*ms99>8utjs9;AQ9xbJ?xFIXAO`(i6!D5jO~s8PPBM)|%P<@>9b_XT-5 z_TqZ@cmf}zffdLZe6QY2%L=9(K6hbZ#1Kyqf5zrBBOP!I@GfBMa%WPFBa%L`Xk_tKMl~2+2%biS=|*gt?l$vx=2H(wobi zAjv3^$IeDw=nqtcszx#>9jj_#;!t9b`N(%c0=vp zji1LT8S`Q+uhd9!VSO=T3Ng}<=?Q4RPdLQ9fxi-`lBYACMbtEJOvWw5sAW1&^Wm22 zSeglOG}LLI8*uX$JyC;Q#7<%}r_=Z-5Ns&lO021paq+!=YPF#yRR2eS^TN-iItNdu@p7C z!wh$qT}uwQU6+^Dwe1owM}lpey_dHK+hpH}qpodXU?eQHNuiN@BucN4sC>NyoRMY; z_#-lbiw)*5P)(ia6&0^d3z-WsXtwwi=2OPf9 zyb`hpTd-Qdz_TEYMU$9>=qRFA6xDuF&DynW(O>HuGc&UUt;{!s`F0g z;L}UMUu(6;J~}n)rt(qin*J&L)hZ8wSDXBk##zB%Lkj*G_W}NTm-C204;#6uob^{oY`n~ zm{OmkfLL~%y4OfqXJfSN%K~d06g69MSP|g(vcSTuc6D>xIcUuFv*4O*s?t!lruvme zWnsTR!8f}LZbTCnUeS1bO*F@JdAw~N=Oi%44fC#>Y;XUOgqq#OjwHF9{fxs6De2Ksx z5cop^wCfEdnSf3XVzsLad;YcO1GI+Io?&V(%G}FoxR*DP!x;N2LY6Xo$2`a%PULcs zKQ?5HLb>w~lq8556z8uXs=tA*L*!f$)t?X8?g#q#RCz}GGPNN2rQ}ZFIWZ}wV0r-wen{Fh7Yu+E)m$3mvD_D*z=A6RdRY8LzJj!QSo(JG>p>(C{TM8SH&06Ro8Wn#8KVqOnvk_LtA{0? z97lga6UR)TgLm&@7{l$55ZYCQc*mYb(U1mhfD2Ny}{e2bu5UWUNT(Oin(j%5gWd? z!M1hY%gw>IdwrK5sB5bWjQFHBXK18eqV!S;jlne zo!E9yc(lCO z0}B2b^-wnNavsWNz!-;hHl7Z42oDZ;U@|2EsNx{Cr{4=@Q|gzXY?>--HN{j3LCJI| zb=1P9Y+7wdT&y@O#ZU?7XuC*uoT{Q28lz>(rq$;Xqm6a$7Y!v2%BF=`?dm2bl+Exg zxaPZ87L=7upHiZvKfG)>A{~g;Oz@kQE&V;)sAq@ znuRrMme#Dfzh=#9UK44|A9(cGLAIm0yg`k2N9AlHi?u?}DbOF(b3O51xCoVqv7j(3 zt?9LZ<3T+chd02c<;atEzRno$9Qra{T(%N%mj*W2VlxA8m8R0AiUE;c~L%8@47729de787Y7gS z^j=;SJh;(!Iofb=d0<444lW3dND`$NNGKm!A^~SamViI9K|)?ZmfI9Ha7 z2iJ$MtRl=S8^wd$LRXpz|H@ABkRx#AL2C6Qw|J<|_ah&*{ZYMmXc=7AQl-n<8;dEK zPV8IrM>MI3_i5iD@Mi?bIcjr6e~z*QGWgSp2|X|T+UXXho56aq@#eb4V$;nnNZs1p zAoA#Dh8#(&4NbktlJp?lF4)*- zpY{XbVa5yk0f0AM4u|8Pgy>J`rl0fwSP1<@*isa>{IhV+FN96ErDcvSzZL*)?`w5z c{b-CjqC5<#c z@+DjGw&)@+h1Yrr;FtEC zd&xW0XrxhOCvj3;%pBf(d3bqwf6lr0JBR$2SS&0+xcTn?WDnK}!hd4I4SpxGRDj5Y zpb46o6S_ru^yQ_DPZaY0Zdqh!0^I?6CfFT>qc5lAL){^IPRfPzk?x2n_=Fz{n*TjP zlkfOg+1*i^7l6DV&x?(SDdlry=t_xp@QrkSFsBoU#%VNL=zT3s(s?7p%3xG<5=&G#^mECIOYP0#H>MXi$S2cPm;DXh>6lhDSoFu)XC?-7qWe<(7}P z3Bq53|I$B06HN#vKX(VuLMVxr*zY_d%VD$x^(je~&yofX0+-+N8-HWUE=PaBjg^m@ z&+_+(8oj2+cgN`x&2NTWo_^YaZi*eyK@2FO-kaAE6D+kML%HJlgyi5mJ$^-?E;#-#INl z^mTMhi=)Sz4-SqTOBvyxIFV#W5@W3?qs7$ogE~o@Lqrdkk{#^C$&_J>gVT~N0TYsA zNKQVCrwv0VCcXF+@IdA%DF4V>)4)eb zk^?~PBppEoy%#3NRbBPWQI!=kdVIBEx5h=Nk&L}SU?j#GZd@9_^uY@YvHdVAr^rI| z&=-;rKd>0vKO5Wsa7h8SU&UIbyYlJa6WPeUzxpMH-+@>ryP>k=2$D7+-iB^!?_Gwh_zP@YeqoCe&cII>z(RAtBUxyr=bvJl<3 z_Q*krM5WrtARxR_%`2fE4A8S*re?MQYYp%bbRW4n+%gCq&WyRKoX zV>_R`11g=7Swp>=&SkaH!>f&16n7QODCP`Qqjb?SRX-nXV#^k~lK2|5J4wOKmr2mg zc_0AB@GYGOxSU`bht7irOya_Lm3crF`lHQ*r};b>q&lBZgF|aggXO&Yp@yDW&@dlu z_RNB>dOZy*?#)>MJ-|X4ll%70g46IQe)@dSxh`Lj*!q z2_VV>IfkSi$qPt2kaQwBhvYnvX@QnXRVcxIBqi8E20rk6!ULaQ-t{2rlMg>oeDZ!y znlZKu$?hc}`&cnpLI(ak;NA|y()WRE9EjIK0LMp0!tn4B-g7!xrYCVdl(Qirre&(9NMiHLsg}MQwDu=LuW9XHGz=2C4`cg7w?q z-!bO@Mc~tIjT5Zd6YS;msfdC1V_zJ6RwSA>$}6D-MXEiku)ORy!d6xur>D`fnG z=dRUzCj`9{!E3&Qy(7DA>1l;(BSUfdwCD$ZYN`rqF)eNdon8#DyBEW5N$%0xi_nvy zk^%xK(M7mQo%Vf>QjPo}5>&4>AaSEFJy-$JOT{6xVnIragObi8B!9w%{2}5%{;OC4 zvVhoyO9xYB$Q|JQX+Oc8E5R9yQ2z2HIL(Xzr>Em=KiGw=0SA139TPIG_j*?a zwjXOU9iadt0Fl%}O#TRKfjd0Xgk&F*W+b@Y^#bY+hR$&JlMMf7=z3eLfkMJT)bZfm zefAPv*GAOAY-Yd!TS9L>Jv_S4E6q&=$g6NY4`mt&=tq?7&)z;iU*9?w_$0pNM$dT9 zjr@3i^1@%g^2=9lZ<~v^j!B;+w%iyQADL11PF@5uHZpl}HnDe6!QTQ)vjnHJkZZ8i zTsbj0voiC4No58K z55#4lLYW1>VPzJqA>O{cGE=xR3u)nROquz&DX*mRq+0kOT6i|s+R2Rw0f0-m29ClsA9?-ORi<{THGp38)_F<*aXQ&?9O z)`-_2aJAI#-$aSn_;zU-`}uC8L=^w2jueFeDGIsDtrWHtD?9**QYlg$3#^9zX^|b! zprxo4wxR>DIV;Ilg!ljhT=Z-}s-a9#3lOFrpBBTrTG1ZKb$#|+w>B>Gq~nDl7w@U@ zL>wQg!54MQ)8Ze<6p@N)39ZhGIfztz-H24&Ey+DPh?D|oR0Od_M5W*#E=2*Es_!Dy zfC*<1?X} z#R0uAdV1xCXq}h73^gz=ls8FF5};K#qqN^VKk?G7mmqTMrT1Tk!`PLZFV800AvU%3 zqs9*#7xx^S-E-_N5X`&z^30xgIK8MiKj#X_Wm$7r0zbKdAJ|09GyutX%LC4p45~0MkuX<4!+f%M}59rOGBB z$&#wQeNyskG6;Gw<4~hQzzTrL2ek0C|MA%@fa19^4lN%RSY?`sJnU^`y;9L>FGy+2y1DJf28ob<;Q-hD=Fj~k#$%0xKOen6gKrvBE(%>nn zit)zcW(T*Hj%fCSI1ZF#5XVuEp@6&sWu_#@Ou5uaegx^>9vc3;Gv)i7%3^LtAgh5c zDF>r}6)q>QAwjgB{3#L@2{z<9{J2sfkB~%+e{D%=;P~T0jiK($Q}J@z%=CdXKG-$X zL88N(1y^-#bpHz3Rh17Pz_<;>i@bkYzindItzCAs5q|e&>Dlm-AZ>|^1@0?CqJe|* z7pE@H#Sed-fIL9eta%>~{AMgHElD&hJG|X%+ zmtpc^5M#ZvdrP;e@MYjr8Rlh~t%JP=>8jS6!D2 zk$eZO`7bnsZbIwSfLax;;cmg3QlHIDFWr>-%e;CvrM_oySbPCq14H4<>#t4f{{d8! z0@xjGEgrdF7|P`IA)`X_!&|CM!W&6Fm+j4hn0+qEd`ptl10~5kP#gt7VVMEA!B?sP zyziai8)t>02dzJ_Vo6RPmb$5X8Gj75T&4y#L3{M0sj3FBTPbKizb0s3_J-A@^81~1 zua7H(im)Ls8R?NdZhYK(2V6v}0mvm8Kxo6ZeUzZA3rR*)hKilx+aS41~O^uK{? zJjlZ1vl(){pr)CCa{{tZYqPOlfyKXiQA4nCS5ZxIma`!Z{d7W)#a+<_g#)mhRZ2lE z3MK-7)!tN~Hx=HS2=Yg%Ng=ih0#yYA(PGr35VxXMoSGB@Rt!uEftn_TxD^AG$p@BF zxT_Ukw@D%5Vl7(TWAobL$WhHv!#QEp)Ya%^uD)vUaZ^DJK5hum>K&6pLffJ>SP5s; z)UDg3Q0KOxdj#AHM-BdC@TVgOj?q#aHbYxSJKiwUrk>LcB8M?|F zSY-|Dc@f_=z1O}7y#|awr&a;FzoS-xPaF2!&dfI)8&f`eYRAM|x89n#bnDX8iI2{F zc;;?k?x|B_;ZL4!oVa@H>O|>QY3j|7-udvIy9eizdn~p z!g1;xkeNhs>fCIic~QaNf|J2goX#rE=zxkGbLHqfM8UBYbMMjXczo&{i+U*({aFI+ z37t28w79i-c55?4fzE9`0FkjPGg}YLCORQ{d*9z4o<6*I=*;Y)Gk^E}*+ZBICua_I z!udtT`MK9XHp`pCLilMG0VcQ^y#!9X(RnNnDvcN3qh*)lZfPusQ(&W5Vi$Z8tIK2|&JhP3#n36dwpk>2v-c DJ?rHR diff --git a/uv.lock b/uv.lock index 99219f7..0e4f242 100644 --- a/uv.lock +++ b/uv.lock @@ -24,6 +24,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, ] +[[package]] +name = "bitsandbytes" +version = "0.49.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "torch" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/7d/f1fe0992334b18cd8494f89aeec1dcc674635584fcd9f115784fea3a1d05/bitsandbytes-0.49.2-py3-none-macosx_14_0_arm64.whl", hash = "sha256:87be5975edeac5396d699ecbc39dfc47cf2c026daaf2d5852a94368611a6823f", size = 131940, upload-time = "2026-02-16T21:26:04.572Z" }, + { url = "https://files.pythonhosted.org/packages/29/71/acff7af06c818664aa87ff73e17a52c7788ad746b72aea09d3cb8e424348/bitsandbytes-0.49.2-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:2fc0830c5f7169be36e60e11f2be067c8f812dfcb829801a8703735842450750", size = 31442815, upload-time = "2026-02-16T21:26:06.783Z" }, + { url = "https://files.pythonhosted.org/packages/19/57/3443d6f183436fbdaf5000aac332c4d5ddb056665d459244a5608e98ae92/bitsandbytes-0.49.2-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:54b771f06e1a3c73af5c7f16ccf0fc23a846052813d4b008d10cb6e017dd1c8c", size = 60651714, upload-time = "2026-02-16T21:26:11.579Z" }, + { url = "https://files.pythonhosted.org/packages/b6/d4/501655842ad6771fb077f576d78cbedb5445d15b1c3c91343ed58ca46f0e/bitsandbytes-0.49.2-py3-none-win_amd64.whl", hash = "sha256:2e0ddd09cd778155388023cbe81f00afbb7c000c214caef3ce83386e7144df7d", size = 55372289, upload-time = "2026-02-16T21:26:16.267Z" }, +] + [[package]] name = "certifi" version = "2026.2.25" @@ -976,6 +992,7 @@ name = "vllmini" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "bitsandbytes" }, { name = "huggingface-hub" }, { name = "packaging" }, { name = "protobuf" }, @@ -987,6 +1004,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "bitsandbytes", specifier = ">=0.49.2" }, { name = "huggingface-hub", specifier = ">=1.11.0" }, { name = "packaging", specifier = ">=26.1" }, { name = "protobuf", specifier = ">=7.34.1" }, From 01a79f38d90ecbc11f780d0895b745fe56b58a79 Mon Sep 17 00:00:00 2001 From: Lothnic Date: Mon, 27 Apr 2026 19:37:53 +0530 Subject: [PATCH 2/6] fixed quantisation and ready to merge --- .gitignore | 4 +++- engine/generator.py | 12 ++++++++-- main.py | 35 +++++++++++++---------------- models/weight_loader.py | 44 +++++++++++++++++++++++++++--------- pyproject.toml | 1 + uv.lock | 50 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 114 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index dc1108d..16c8f68 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ future_plans/ **/__pycache__/ -*.pyc \ No newline at end of file +*.pyc + +docs/quantisation.md \ No newline at end of file diff --git a/engine/generator.py b/engine/generator.py index 8dab21f..b3c477f 100644 --- a/engine/generator.py +++ b/engine/generator.py @@ -12,7 +12,9 @@ def __init__(self, model, tokenizer, sampler: Sampler | None = None): @torch.inference_mode() def generate(self, prompt: str, max_new_tokens: int = 50, params: SamplingParams | None = None): input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.config.device) + prompt_len = input_ids.shape[1] past_key_values = None + prev_text = "" for _ in range(max_new_tokens): if past_key_values is None: @@ -25,5 +27,11 @@ def generate(self, prompt: str, max_new_tokens: int = 50, params: SamplingParams if next_token.item() == self.tokenizer.eos_token_id: break - - yield self.tokenizer.decode(next_token[0], skip_special_tokens=True) \ No newline at end of file + + # Decode all generated tokens so far and yield only the new text. + # This correctly handles SentencePiece space prefixes and multi-byte chars. + full_text = self.tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True) + new_text = full_text[len(prev_text):] + prev_text = full_text + if new_text: + yield new_text \ No newline at end of file diff --git a/main.py b/main.py index 94cf902..4049548 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,9 @@ """CLI entry point.""" import os +import warnings +warnings.filterwarnings("ignore", message=".*_check_is_size.*", category=FutureWarning) import argparse import torch -from huggingface_hub import try_to_load_from_cache from transformers import AutoTokenizer from models.weight_loader import load_hf_model @@ -38,27 +39,18 @@ def strip_thinking(output: str) -> str: def main(): args = parse_args() - # Resolve local cache path for the actual requested model - # If cached, enable offline mode to prevent any network calls - _cached = try_to_load_from_cache(args.model_id, "config.json") - local_model_path = os.path.dirname(_cached) if isinstance(_cached, str) else None - - if local_model_path: - os.environ["HF_HUB_OFFLINE"] = "1" - os.environ["TRANSFORMERS_OFFLINE"] = "1" - else: - # Ensure offline mode is NOT set if the model isn't cached - os.environ.pop("HF_HUB_OFFLINE", None) - os.environ.pop("TRANSFORMERS_OFFLINE", None) + # Don't force offline mode for model loading — the weight_loader + # handles local-first-then-download fallback on its own. + os.environ.pop("HF_HUB_OFFLINE", None) + os.environ.pop("TRANSFORMERS_OFFLINE", None) model, config = load_hf_model(args.model_id, device=args.device, quantize=args.quantize) + + tokenizer = AutoTokenizer.from_pretrained(args.model_id) + # chat = [{"role": "user", "content": "Write a short story about a robot."}] + # prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) - # Use local model path if available (prevents network calls for tokenizer) - tokenizer = AutoTokenizer.from_pretrained(local_model_path or args.model_id) - chat = [{"role": "user", "content": "Write a short story about a robot."}] - prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) - - # prompt = "Write a very long story about a robot." + prompt = "Write a very long story about a robot." params = SamplingParams(temperature=args.temperature, top_p=args.top_p) sampler = Sampler() @@ -122,6 +114,11 @@ def main(): if remainder: print(remainder, end="", flush=True) parts.append(remainder) + elif not indicator_shown and len(buffer) > 20: + # Model doesn't use tags — flush buffer and stream normally + thinking_done = True + print(buffer, end="", flush=True) + parts.append(buffer) # Otherwise keep accumulating silently else: # Either HIDE_THINKING is False, or we're past diff --git a/models/weight_loader.py b/models/weight_loader.py index 51c68ba..3df02c9 100644 --- a/models/weight_loader.py +++ b/models/weight_loader.py @@ -39,16 +39,33 @@ def _resolve_parameter(model: nn.Module, key: str): def _find_shard_paths(model_id: str, local_only: bool) -> list[str]: - """Return list of safetensors shard paths (single file or multi-shard).""" + """Return list of safetensors shard paths (single file or multi-shard). + + Tries local cache first. If the file isn't cached *and* ``local_only`` + is ``False``, transparently falls back to downloading from the Hub. + """ + + def _download(filename: str, *, must_be_local: bool) -> str: + """Try local first, then online if allowed.""" + try: + return hf_hub_download(repo_id=model_id, filename=filename, local_files_only=True) + except Exception: + if must_be_local: + raise + return hf_hub_download(repo_id=model_id, filename=filename, local_files_only=False) + + # 1. Try single-file model try: - path = hf_hub_download(repo_id=model_id, filename="model.safetensors", local_files_only=local_only) - return [path] + return [_download("model.safetensors", must_be_local=local_only)] except Exception: - index_path = hf_hub_download(repo_id=model_id, filename="model.safetensors.index.json", local_files_only=local_only) - with open(index_path, "r") as f: - index = json.load(f) - shard_files = set(index["weight_map"].values()) - return [hf_hub_download(repo_id=model_id, filename=f, local_files_only=local_only) for f in shard_files] + pass + + # 2. Multi-shard model — get the index first, then each shard + index_path = _download("model.safetensors.index.json", must_be_local=local_only) + with open(index_path, "r") as f: + index = json.load(f) + shard_files = sorted(set(index["weight_map"].values())) + return [_download(f, must_be_local=local_only) for f in shard_files] def _load_standard(model, shard_paths, device, dtype): @@ -73,16 +90,20 @@ def _load_quantized(model, shard_paths, device, dtype): for path in shard_paths: shard = load_file(path, device="cpu") # Always load to CPU first - for k, v in shard.items(): + keys = list(shard.keys()) + for k in keys: + v = shard.pop(k) # Pop to free memory as we iterate new_k = _remap_key(k) try: target, attr_name = _resolve_parameter(model, new_k) param = getattr(target, attr_name) except AttributeError: + del v continue # Skip unmapped keys (e.g. keys we don't use) v_typed = v.to(dtype=dtype) + del v # Free original tensor immediately if hasattr(param, "quant_type"): # This is a Linear4bit parameter — quantize and place on device @@ -90,7 +111,9 @@ def _load_quantized(model, shard_paths, device, dtype): v_typed, requires_grad=False, quant_type=getattr(param, "quant_type", "nf4"), - ).to(device) + ) + del v_typed # Free CPU copy before GPU allocation + new_param = new_param.to(device) setattr(target, attr_name, new_param) else: # Normal parameter (embeddings, norms, lm_head, etc.) @@ -98,6 +121,7 @@ def _load_quantized(model, shard_paths, device, dtype): attr_name, nn.Parameter(v_typed.to(device), requires_grad=False), ) + del v_typed del shard gc.collect() diff --git a/pyproject.toml b/pyproject.toml index 1498a57..671fb23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "protobuf>=7.34.1", "pytest>=9.0.3", "safetensors>=0.7.0", + "sentencepiece>=0.2.1", "torch>=2.11.0", "transformers>=5.5.4", ] diff --git a/uv.lock b/uv.lock index 0e4f242..1deac9c 100644 --- a/uv.lock +++ b/uv.lock @@ -815,6 +815,54 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, ] +[[package]] +name = "sentencepiece" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/be/32ce495aa1d0e0c323dcb1ba87096037358edee539cac5baf8755a6bd396/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:57cae326c8727de58c85977b175af132a7138d84c764635d7e71bbee7e774133", size = 1943152, upload-time = "2025-08-12T06:59:40.048Z" }, + { url = "https://files.pythonhosted.org/packages/88/7e/ff23008899a58678e98c6ff592bf4d368eee5a71af96d0df6b38a039dd4f/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:56dd39a3c4d6493db3cdca7e8cc68c6b633f0d4195495cbadfcf5af8a22d05a6", size = 1325651, upload-time = "2025-08-12T06:59:41.536Z" }, + { url = "https://files.pythonhosted.org/packages/19/84/42eb3ce4796777a1b5d3699dfd4dca85113e68b637f194a6c8d786f16a04/sentencepiece-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d9381351182ff9888cc80e41c632e7e274b106f450de33d67a9e8f6043da6f76", size = 1253645, upload-time = "2025-08-12T06:59:42.903Z" }, + { url = "https://files.pythonhosted.org/packages/89/fa/d3d5ebcba3cb9e6d3775a096251860c41a6bc53a1b9461151df83fe93255/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99f955df238021bf11f0fc37cdb54fd5e5b5f7fd30ecc3d93fb48b6815437167", size = 1316273, upload-time = "2025-08-12T06:59:44.476Z" }, + { url = "https://files.pythonhosted.org/packages/04/88/14f2f4a2b922d8b39be45bf63d79e6cd3a9b2f248b2fcb98a69b12af12f5/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdfecef430d985f1c2bcbfff3defd1d95dae876fbd0173376012d2d7d24044b", size = 1387881, upload-time = "2025-08-12T06:59:46.09Z" }, + { url = "https://files.pythonhosted.org/packages/fd/b8/903e5ccb77b4ef140605d5d71b4f9e0ad95d456d6184688073ed11712809/sentencepiece-0.2.1-cp312-cp312-win32.whl", hash = "sha256:a483fd29a34c3e34c39ac5556b0a90942bec253d260235729e50976f5dba1068", size = 999540, upload-time = "2025-08-12T06:59:48.023Z" }, + { url = "https://files.pythonhosted.org/packages/2d/81/92df5673c067148c2545b1bfe49adfd775bcc3a169a047f5a0e6575ddaca/sentencepiece-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4cdc7c36234fda305e85c32949c5211faaf8dd886096c7cea289ddc12a2d02de", size = 1054671, upload-time = "2025-08-12T06:59:49.895Z" }, + { url = "https://files.pythonhosted.org/packages/fe/02/c5e3bc518655d714622bec87d83db9cdba1cd0619a4a04e2109751c4f47f/sentencepiece-0.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:daeb5e9e9fcad012324807856113708614d534f596d5008638eb9b40112cd9e4", size = 1033923, upload-time = "2025-08-12T06:59:51.952Z" }, + { url = "https://files.pythonhosted.org/packages/ba/4a/85fbe1706d4d04a7e826b53f327c4b80f849cf1c7b7c5e31a20a97d8f28b/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706", size = 1943150, upload-time = "2025-08-12T06:59:53.588Z" }, + { url = "https://files.pythonhosted.org/packages/c2/83/4cfb393e287509fc2155480b9d184706ef8d9fa8cbf5505d02a5792bf220/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062", size = 1325651, upload-time = "2025-08-12T06:59:55.073Z" }, + { url = "https://files.pythonhosted.org/packages/8d/de/5a007fb53b1ab0aafc69d11a5a3dd72a289d5a3e78dcf2c3a3d9b14ffe93/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff", size = 1253641, upload-time = "2025-08-12T06:59:56.562Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d2/f552be5928105588f4f4d66ee37dd4c61460d8097e62d0e2e0eec41bc61d/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820", size = 1316271, upload-time = "2025-08-12T06:59:58.109Z" }, + { url = "https://files.pythonhosted.org/packages/96/df/0cfe748ace5485be740fed9476dee7877f109da32ed0d280312c94ec259f/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47", size = 1387882, upload-time = "2025-08-12T07:00:00.701Z" }, + { url = "https://files.pythonhosted.org/packages/ac/dd/f7774d42a881ced8e1739f393ab1e82ece39fc9abd4779e28050c2e975b5/sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f", size = 999541, upload-time = "2025-08-12T07:00:02.709Z" }, + { url = "https://files.pythonhosted.org/packages/dd/e9/932b9eae6fd7019548321eee1ab8d5e3b3d1294df9d9a0c9ac517c7b636d/sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b", size = 1054669, upload-time = "2025-08-12T07:00:04.915Z" }, + { url = "https://files.pythonhosted.org/packages/c9/3a/76488a00ea7d6931689cda28726a1447d66bf1a4837943489314593d5596/sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd", size = 1033922, upload-time = "2025-08-12T07:00:06.496Z" }, + { url = "https://files.pythonhosted.org/packages/4a/b6/08fe2ce819e02ccb0296f4843e3f195764ce9829cbda61b7513f29b95718/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94", size = 1946052, upload-time = "2025-08-12T07:00:08.136Z" }, + { url = "https://files.pythonhosted.org/packages/ab/d9/1ea0e740591ff4c6fc2b6eb1d7510d02f3fb885093f19b2f3abd1363b402/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07", size = 1327408, upload-time = "2025-08-12T07:00:09.572Z" }, + { url = "https://files.pythonhosted.org/packages/99/7e/1fb26e8a21613f6200e1ab88824d5d203714162cf2883248b517deb500b7/sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c", size = 1254857, upload-time = "2025-08-12T07:00:11.021Z" }, + { url = "https://files.pythonhosted.org/packages/bc/85/c72fd1f3c7a6010544d6ae07f8ddb38b5e2a7e33bd4318f87266c0bbafbf/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596", size = 1315722, upload-time = "2025-08-12T07:00:12.989Z" }, + { url = "https://files.pythonhosted.org/packages/4a/e8/661e5bd82a8aa641fd6c1020bd0e890ef73230a2b7215ddf9c8cd8e941c2/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6", size = 1387452, upload-time = "2025-08-12T07:00:15.088Z" }, + { url = "https://files.pythonhosted.org/packages/99/5e/ae66c361023a470afcbc1fbb8da722c72ea678a2fcd9a18f1a12598c7501/sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b", size = 1002501, upload-time = "2025-08-12T07:00:16.966Z" }, + { url = "https://files.pythonhosted.org/packages/c1/03/d332828c4ff764e16c1b56c2c8f9a33488bbe796b53fb6b9c4205ddbf167/sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484", size = 1057555, upload-time = "2025-08-12T07:00:18.573Z" }, + { url = "https://files.pythonhosted.org/packages/88/14/5aee0bf0864df9bd82bd59e7711362908e4935e3f9cdc1f57246b5d5c9b9/sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0", size = 1036042, upload-time = "2025-08-12T07:00:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/24/9c/89eb8b2052f720a612478baf11c8227dcf1dc28cd4ea4c0c19506b5af2a2/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:5d0350b686c320068702116276cfb26c066dc7e65cfef173980b11bb4d606719", size = 1943147, upload-time = "2025-08-12T07:00:21.809Z" }, + { url = "https://files.pythonhosted.org/packages/82/0b/a1432bc87f97c2ace36386ca23e8bd3b91fb40581b5e6148d24b24186419/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c7f54a31cde6fa5cb030370566f68152a742f433f8d2be458463d06c208aef33", size = 1325624, upload-time = "2025-08-12T07:00:23.289Z" }, + { url = "https://files.pythonhosted.org/packages/ea/99/bbe054ebb5a5039457c590e0a4156ed073fb0fe9ce4f7523404dd5b37463/sentencepiece-0.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c83b85ab2d6576607f31df77ff86f28182be4a8de6d175d2c33ca609925f5da1", size = 1253670, upload-time = "2025-08-12T07:00:24.69Z" }, + { url = "https://files.pythonhosted.org/packages/19/ad/d5c7075f701bd97971d7c2ac2904f227566f51ef0838dfbdfdccb58cd212/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1855f57db07b51fb51ed6c9c452f570624d2b169b36f0f79ef71a6e6c618cd8b", size = 1316247, upload-time = "2025-08-12T07:00:26.435Z" }, + { url = "https://files.pythonhosted.org/packages/fb/03/35fbe5f3d9a7435eebd0b473e09584bd3cc354ce118b960445b060d33781/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01e6912125cb45d3792f530a4d38f8e21bf884d6b4d4ade1b2de5cf7a8d2a52b", size = 1387894, upload-time = "2025-08-12T07:00:28.339Z" }, + { url = "https://files.pythonhosted.org/packages/dc/aa/956ef729aafb6c8f9c443104c9636489093bb5c61d6b90fc27aa1a865574/sentencepiece-0.2.1-cp314-cp314-win32.whl", hash = "sha256:c415c9de1447e0a74ae3fdb2e52f967cb544113a3a5ce3a194df185cbc1f962f", size = 1096698, upload-time = "2025-08-12T07:00:29.764Z" }, + { url = "https://files.pythonhosted.org/packages/b8/cb/fe400d8836952cc535c81a0ce47dc6875160e5fedb71d2d9ff0e9894c2a6/sentencepiece-0.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:881b2e44b14fc19feade3cbed314be37de639fc415375cefaa5bc81a4be137fd", size = 1155115, upload-time = "2025-08-12T07:00:32.865Z" }, + { url = "https://files.pythonhosted.org/packages/32/89/047921cf70f36c7b6b6390876b2399b3633ab73b8d0cb857e5a964238941/sentencepiece-0.2.1-cp314-cp314-win_arm64.whl", hash = "sha256:2005242a16d2dc3ac5fe18aa7667549134d37854823df4c4db244752453b78a8", size = 1133890, upload-time = "2025-08-12T07:00:34.763Z" }, + { url = "https://files.pythonhosted.org/packages/a1/11/5b414b9fae6255b5fb1e22e2ed3dc3a72d3a694e5703910e640ac78346bb/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:a19adcec27c524cb7069a1c741060add95f942d1cbf7ad0d104dffa0a7d28a2b", size = 1946081, upload-time = "2025-08-12T07:00:36.97Z" }, + { url = "https://files.pythonhosted.org/packages/77/eb/7a5682bb25824db8545f8e5662e7f3e32d72a508fdce086029d89695106b/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:e37e4b4c4a11662b5db521def4e44d4d30ae69a1743241412a93ae40fdcab4bb", size = 1327406, upload-time = "2025-08-12T07:00:38.669Z" }, + { url = "https://files.pythonhosted.org/packages/03/b0/811dae8fb9f2784e138785d481469788f2e0d0c109c5737372454415f55f/sentencepiece-0.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:477c81505db072b3ab627e7eab972ea1025331bd3a92bacbf798df2b75ea86ec", size = 1254846, upload-time = "2025-08-12T07:00:40.611Z" }, + { url = "https://files.pythonhosted.org/packages/ef/23/195b2e7ec85ebb6a547969f60b723c7aca5a75800ece6cc3f41da872d14e/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:010f025a544ef770bb395091d57cb94deb9652d8972e0d09f71d85d5a0816c8c", size = 1315721, upload-time = "2025-08-12T07:00:42.914Z" }, + { url = "https://files.pythonhosted.org/packages/7e/aa/553dbe4178b5f23eb28e59393dddd64186178b56b81d9b8d5c3ff1c28395/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:733e59ff1794d26db706cd41fc2d7ca5f6c64a820709cb801dc0ea31780d64ab", size = 1387458, upload-time = "2025-08-12T07:00:44.56Z" }, + { url = "https://files.pythonhosted.org/packages/66/7c/08ff0012507297a4dd74a5420fdc0eb9e3e80f4e88cab1538d7f28db303d/sentencepiece-0.2.1-cp314-cp314t-win32.whl", hash = "sha256:d3233770f78e637dc8b1fda2cd7c3b99ec77e7505041934188a4e7fe751de3b0", size = 1099765, upload-time = "2025-08-12T07:00:46.058Z" }, + { url = "https://files.pythonhosted.org/packages/91/d5/2a69e1ce15881beb9ddfc7e3f998322f5cedcd5e4d244cb74dade9441663/sentencepiece-0.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:5e4366c97b68218fd30ea72d70c525e6e78a6c0a88650f57ac4c43c63b234a9d", size = 1157807, upload-time = "2025-08-12T07:00:47.673Z" }, + { url = "https://files.pythonhosted.org/packages/f3/16/54f611fcfc2d1c46cbe3ec4169780b2cfa7cf63708ef2b71611136db7513/sentencepiece-0.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:105e36e75cbac1292642045458e8da677b2342dcd33df503e640f0b457cb6751", size = 1136264, upload-time = "2025-08-12T07:00:49.485Z" }, +] + [[package]] name = "setuptools" version = "81.0.0" @@ -998,6 +1046,7 @@ dependencies = [ { name = "protobuf" }, { name = "pytest" }, { name = "safetensors" }, + { name = "sentencepiece" }, { name = "torch" }, { name = "transformers" }, ] @@ -1010,6 +1059,7 @@ requires-dist = [ { name = "protobuf", specifier = ">=7.34.1" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "safetensors", specifier = ">=0.7.0" }, + { name = "sentencepiece", specifier = ">=0.2.1" }, { name = "torch", specifier = ">=2.11.0" }, { name = "transformers", specifier = ">=5.5.4" }, ] From 6b399899aac4f43d633da04674e78ad9755a1026 Mon Sep 17 00:00:00 2001 From: Lothnic Date: Mon, 27 Apr 2026 20:47:46 +0530 Subject: [PATCH 3/6] fix: corrected weight validation logic to prevent silent initialization of missing weights with random memory --- models/weight_loader.py | 19 ++++++++++--------- test_decode.py | 8 ++++++++ test_decode2.py | 8 ++++++++ test_decode3.py | 8 ++++++++ test_decode4.py | 12 ++++++++++++ test_decode5.py | 5 +++++ test_decode6.py | 1 + 7 files changed, 52 insertions(+), 9 deletions(-) create mode 100644 test_decode.py create mode 100644 test_decode2.py create mode 100644 test_decode3.py create mode 100644 test_decode4.py create mode 100644 test_decode5.py create mode 100644 test_decode6.py diff --git a/models/weight_loader.py b/models/weight_loader.py index 3df02c9..aad389c 100644 --- a/models/weight_loader.py +++ b/models/weight_loader.py @@ -195,7 +195,15 @@ def load_hf_model(model_id: str, device: str = "cuda", dtype: torch.dtype = torc module.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) module.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - # 6. Ensure all remaining meta tensors are materialized on the correct device + # 6. Check for missing parameters and buffers before materializing them + missing_params = [name for name, param in model.named_parameters() if param.is_meta] + missing_buffers = [name for name, buf in model.named_buffers() if buf.is_meta] + + real_missing = [k for k in missing_params if "lm_head" not in k] + if real_missing or missing_buffers: + raise ValueError(f"Missing weights/buffers in checkpoint! params={real_missing}, buffers={missing_buffers}") + + # 7. Ensure all remaining allowed meta tensors (e.g. lm_head) are materialized for param in model.parameters(): if param.is_meta: param.data = torch.empty_like(param, device=device) @@ -203,18 +211,11 @@ def load_hf_model(model_id: str, device: str = "cuda", dtype: torch.dtype = torc if buffer.is_meta: buffer.data = torch.empty_like(buffer, device=device) - # 7. Move model to target device/dtype + # 8. Move model to target device/dtype # Note: for quantized models, bnb parameters handle their own dtype, # model.to() will skip them automatically model.to(device, dtype=dtype) - # 8. Check for missing parameters - missing_keys = [name for name, param in model.named_parameters() if param.is_meta] - if missing_keys: - real_missing = [k for k in missing_keys if "lm_head" not in k] - if real_missing: - print(f"Missing: {real_missing}") - config.device = device model.eval() if torch.cuda.is_available(): diff --git a/test_decode.py b/test_decode.py new file mode 100644 index 0000000..31903fa --- /dev/null +++ b/test_decode.py @@ -0,0 +1,8 @@ +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + +tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] +print(tokens) +print("Decode all:", repr(tokenizer.decode(tokens))) +print("Decode first token:", repr(tokenizer.decode(tokens[1:2]))) +print("Decode second token:", repr(tokenizer.decode(tokens[2:3]))) diff --git a/test_decode2.py b/test_decode2.py new file mode 100644 index 0000000..84cffd9 --- /dev/null +++ b/test_decode2.py @@ -0,0 +1,8 @@ +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + +tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] +print(tokens) +t1 = tokenizer.decode(tokens[1:2]) +t2 = tokenizer.decode(tokens[2:3]) +print(repr(t1 + t2)) diff --git a/test_decode3.py b/test_decode3.py new file mode 100644 index 0000000..5e004d8 --- /dev/null +++ b/test_decode3.py @@ -0,0 +1,8 @@ +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + +tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] +print(tokens) +t1 = tokenizer.decode(tokens[1:2], clean_up_tokenization_spaces=False) +t2 = tokenizer.decode(tokens[2:3], clean_up_tokenization_spaces=False) +print(repr(t1 + t2)) diff --git a/test_decode4.py b/test_decode4.py new file mode 100644 index 0000000..d305683 --- /dev/null +++ b/test_decode4.py @@ -0,0 +1,12 @@ +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + +tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] +print("tokens:", tokens) + +prev_token_text = tokenizer.decode(tokens[1:2]) +full_text = tokenizer.decode(tokens[1:3]) +new_text = full_text[len(prev_token_text):] +print("prev_token_text:", repr(prev_token_text)) +print("full_text:", repr(full_text)) +print("new_text:", repr(new_text)) diff --git a/test_decode5.py b/test_decode5.py new file mode 100644 index 0000000..4c1157a --- /dev/null +++ b/test_decode5.py @@ -0,0 +1,5 @@ +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") + +tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] +print(tokenizer.convert_ids_to_tokens(tokens)) diff --git a/test_decode6.py b/test_decode6.py new file mode 100644 index 0000000..87694e5 --- /dev/null +++ b/test_decode6.py @@ -0,0 +1 @@ +# Just a dummy action to conform with tool usage if needed, though I can just answer. From 8b9769052b5f628580abe25dc7bf3d6b9a793670 Mon Sep 17 00:00:00 2001 From: Lothnic Date: Mon, 27 Apr 2026 20:52:31 +0530 Subject: [PATCH 4/6] removed temporary decoding files --- test_decode.py | 8 -------- test_decode2.py | 8 -------- test_decode3.py | 8 -------- test_decode4.py | 12 ------------ test_decode5.py | 5 ----- test_decode6.py | 1 - 6 files changed, 42 deletions(-) delete mode 100644 test_decode.py delete mode 100644 test_decode2.py delete mode 100644 test_decode3.py delete mode 100644 test_decode4.py delete mode 100644 test_decode5.py delete mode 100644 test_decode6.py diff --git a/test_decode.py b/test_decode.py deleted file mode 100644 index 31903fa..0000000 --- a/test_decode.py +++ /dev/null @@ -1,8 +0,0 @@ -from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - -tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] -print(tokens) -print("Decode all:", repr(tokenizer.decode(tokens))) -print("Decode first token:", repr(tokenizer.decode(tokens[1:2]))) -print("Decode second token:", repr(tokenizer.decode(tokens[2:3]))) diff --git a/test_decode2.py b/test_decode2.py deleted file mode 100644 index 84cffd9..0000000 --- a/test_decode2.py +++ /dev/null @@ -1,8 +0,0 @@ -from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - -tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] -print(tokens) -t1 = tokenizer.decode(tokens[1:2]) -t2 = tokenizer.decode(tokens[2:3]) -print(repr(t1 + t2)) diff --git a/test_decode3.py b/test_decode3.py deleted file mode 100644 index 5e004d8..0000000 --- a/test_decode3.py +++ /dev/null @@ -1,8 +0,0 @@ -from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - -tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] -print(tokens) -t1 = tokenizer.decode(tokens[1:2], clean_up_tokenization_spaces=False) -t2 = tokenizer.decode(tokens[2:3], clean_up_tokenization_spaces=False) -print(repr(t1 + t2)) diff --git a/test_decode4.py b/test_decode4.py deleted file mode 100644 index d305683..0000000 --- a/test_decode4.py +++ /dev/null @@ -1,12 +0,0 @@ -from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - -tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] -print("tokens:", tokens) - -prev_token_text = tokenizer.decode(tokens[1:2]) -full_text = tokenizer.decode(tokens[1:3]) -new_text = full_text[len(prev_token_text):] -print("prev_token_text:", repr(prev_token_text)) -print("full_text:", repr(full_text)) -print("new_text:", repr(new_text)) diff --git a/test_decode5.py b/test_decode5.py deleted file mode 100644 index 4c1157a..0000000 --- a/test_decode5.py +++ /dev/null @@ -1,5 +0,0 @@ -from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - -tokens = tokenizer(" Hello world", return_tensors="pt").input_ids[0] -print(tokenizer.convert_ids_to_tokens(tokens)) diff --git a/test_decode6.py b/test_decode6.py deleted file mode 100644 index 87694e5..0000000 --- a/test_decode6.py +++ /dev/null @@ -1 +0,0 @@ -# Just a dummy action to conform with tool usage if needed, though I can just answer. From 47eee60a335fdc67d8decbbff592e2bc617a981a Mon Sep 17 00:00:00 2001 From: Lothnic Date: Mon, 27 Apr 2026 21:14:43 +0530 Subject: [PATCH 5/6] improved logging in weight_loader --- models/weight_loader.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/models/weight_loader.py b/models/weight_loader.py index aad389c..87c271e 100644 --- a/models/weight_loader.py +++ b/models/weight_loader.py @@ -6,7 +6,12 @@ from safetensors.torch import load_file from models.llama import LlamaConfig, LlamaForCausalLM from models.qwen3 import QwenForCausalLM +import logging from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError + +# Initialize logger +logger = logging.getLogger(__name__) # REGISTRY @@ -49,7 +54,7 @@ def _download(filename: str, *, must_be_local: bool) -> str: """Try local first, then online if allowed.""" try: return hf_hub_download(repo_id=model_id, filename=filename, local_files_only=True) - except Exception: + except EntryNotFoundError: if must_be_local: raise return hf_hub_download(repo_id=model_id, filename=filename, local_files_only=False) @@ -57,8 +62,12 @@ def _download(filename: str, *, must_be_local: bool) -> str: # 1. Try single-file model try: return [_download("model.safetensors", must_be_local=local_only)] - except Exception: + except EntryNotFoundError: + # Expected if model is sharded pass + except Exception as e: + logger.error(f"Error checking for single-file model: {e}") + raise # 2. Multi-shard model — get the index first, then each shard index_path = _download("model.safetensors.index.json", must_be_local=local_only) From 5560688796b0596ccb5ebf43f05c6f0921136f0b Mon Sep 17 00:00:00 2001 From: Lothnic Date: Mon, 27 Apr 2026 21:17:22 +0530 Subject: [PATCH 6/6] fix: removed lm_head exception so as to raise ValueError on missing lm_head --- models/weight_loader.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/models/weight_loader.py b/models/weight_loader.py index 87c271e..3da8023 100644 --- a/models/weight_loader.py +++ b/models/weight_loader.py @@ -208,17 +208,16 @@ def load_hf_model(model_id: str, device: str = "cuda", dtype: torch.dtype = torc missing_params = [name for name, param in model.named_parameters() if param.is_meta] missing_buffers = [name for name, buf in model.named_buffers() if buf.is_meta] - real_missing = [k for k in missing_params if "lm_head" not in k] - if real_missing or missing_buffers: - raise ValueError(f"Missing weights/buffers in checkpoint! params={real_missing}, buffers={missing_buffers}") + if missing_params or missing_buffers: + raise ValueError(f"Missing weights/buffers in checkpoint! params={missing_params}, buffers={missing_buffers}") - # 7. Ensure all remaining allowed meta tensors (e.g. lm_head) are materialized + # 7. Ensure all remaining allowed meta tensors are materialized deterministically for param in model.parameters(): if param.is_meta: - param.data = torch.empty_like(param, device=device) + param.data = torch.zeros_like(param, device=device) for buffer in model.buffers(): if buffer.is_meta: - buffer.data = torch.empty_like(buffer, device=device) + buffer.data = torch.zeros_like(buffer, device=device) # 8. Move model to target device/dtype # Note: for quantized models, bnb parameters handle their own dtype,