本篇使用 Google Colab 完成训练与推理链路,GPU 使用 T4。
目标是实现 BI 数据解读的双模型路由:默认走 0.8B,低置信度或规则命中时回退到 2B。从最终的结果来看,0.8B 模型已经能够胜任大部分的解读任务,只有在输入特别复杂或者模型输出质量较低时才会回退到 2B,这样在保证解读质量的同时最大化效率。
目录
- 1.安装依赖
- 1.1重启运行时
- 2.重启后检查版本
- 3.检查 GPU
- 4.生成数据
- 5.swanlab(如果没有账号,可以不用执行这个模块)
- 6.准备测试数据
- 7.模型懒加载
- 8.挂载 LoRA Adapter(如果没有训练好的,可以不用挂载)
- 9.路由规则与质量评分
- 10.推理函数(单模型)
- 11.路由执行(0.8B主路+2B回退)
- 12.单条测试
- 13.批量路由+swanlab记录
- 14.保存结果
代码实现
1. 安装依赖
!pip -q uninstall -y transformers
!pip -q install -U --no-cache-dir --force-reinstall --no-deps git+https://github.com/huggingface/transformers.git@main
!pip -q install -U datasets peft accelerate swanlab
!pip -q install -U swanlab

1.1 重启运行时

import os
os.kill(os.getpid(), 9)

2. 重启后检查版本
import transformers
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES

# Verify the dev build installed from git main is active, and that it
# registers the Qwen3.5 architecture (key "qwen3_5" in the config mapping).
print("transformers:", transformers.__version__)
print("path:", transformers.__file__)
print("qwen3_5 supported:", "qwen3_5" in CONFIG_MAPPING_NAMES)

# --- 3. Check GPU ---
!nvidia-smi

4. 生成数据
import json
import random
from datetime import datetime, timedelta
from pathlib import Path

# Fixed seed so the generated SFT dataset is reproducible across runs.
random.seed(42)

N = 300
DATA_PATH = "/content/bi_sft_300.json"

# (name, unit, low, high): sampling range for each synthetic KPI.
indicators = [
    ("GMV", "万元", 800, 5000),
    ("订单数", "单", 3000, 20000),
    ("转化率", "%", 1.0, 6.0),
    ("退款率", "%", 0.5, 15.0),
    ("客单价", "元", 80, 600),
    ("新客数", "人", 200, 5000),
    ("ROAS", "", 1.0, 8.0),
    ("毛利率", "%", 8.0, 40.0),
]

# Dimension vocabularies used to build the per-day breakdowns.
channels = ["信息流", "搜索", "站内", "社媒", "自然流量"]
regions = ["华东", "华南", "华北", "华中", "西南"]
categories = ["3C数码", "美妆", "食品", "家居", "服饰"]
def pick_metric(metrics, name):
    """Return the first entry of *metrics* whose "name" field equals *name*.

    Raises StopIteration when nothing matches, mirroring next() on an
    exhausted generator.
    """
    candidates = (entry for entry in metrics if entry["name"] == name)
    return next(candidates)
# Build N synthetic daily BI reports plus rule-derived "gold" interpretations,
# then persist them as a SFT dataset (system / instruction / input / output).
samples = []
base_date = datetime(2026, 3, 6)

for i in range(N):
    day = (base_date - timedelta(days=i)).strftime("%Y-%m-%d")

    # Sample a value, month-over-month and year-over-year delta per KPI.
    metrics = []
    for name, unit, lo, hi in indicators:
        value = round(random.uniform(lo, hi), 2)
        mom = round(random.uniform(-0.30, 0.30), 3)
        yoy = round(random.uniform(-0.40, 0.40), 3)
        metrics.append({"name": name, "unit": unit, "value": value, "mom": mom, "yoy": yoy})

    payload = {
        "date": day,
        "metrics": metrics,
        "dimensions": {
            "channel": [{"name": c, "gmv_mom": round(random.uniform(-0.5, 0.5), 3)} for c in random.sample(channels, 3)],
            "region": [{"name": r, "gmv_mom": round(random.uniform(-0.5, 0.5), 3)} for r in random.sample(regions, 3)],
            "category": [{"name": c, "gmv_mom": round(random.uniform(-0.5, 0.5), 3)} for c in random.sample(categories, 3)],
        },
    }

    gmv = pick_metric(metrics, "GMV")
    conv = pick_metric(metrics, "转化率")
    refund = pick_metric(metrics, "退款率")
    roas = pick_metric(metrics, "ROAS")
    margin = pick_metric(metrics, "毛利率")

    highlights, risks, actions = [], [], []

    # Rule-based highlights.
    if gmv["mom"] >= 0.08:
        highlights.append(f"GMV环比上升{gmv['mom']*100:.1f}%")
    if conv["mom"] >= 0.05:
        highlights.append(f"转化率环比提升{conv['mom']*100:.1f}%")
    if refund["mom"] <= -0.02:
        highlights.append("退款率环比下降")

    # Rule-based risks, each paired with a suggested action.
    if gmv["mom"] <= -0.08:
        risks.append(f"GMV环比下降{abs(gmv['mom']*100):.1f}%")
        actions.append("排查下滑渠道与活动投放")
    if conv["mom"] <= -0.04:
        risks.append(f"转化率环比下降{abs(conv['mom']*100):.1f}%")
        actions.append("优化落地页和转化链路")
    if refund["value"] >= 8:
        risks.append("退款率偏高")
        actions.append("复盘高退款SKU和售后流程")
    if roas["value"] <= 2:
        risks.append("ROAS偏低")
        actions.append("优化投放结构和素材")
    if margin["mom"] <= -0.03:
        risks.append("毛利率下滑")
        actions.append("关注折扣与成本变化")

    # Fallback phrasing so every field is non-empty.
    if not highlights:
        highlights = ["整体波动可控"]
    if not risks:
        risks = ["暂无明显风险"]
    if not actions:
        actions = ["持续观察核心指标趋势"]

    output_obj = {
        "summary": f"{day} BI日报:GMV环比{gmv['mom']*100:.1f}%,转化率环比{conv['mom']*100:.1f}%,退款率{refund['value']}%。",
        "highlights": highlights,
        "risks": risks,
        "actions": actions,
    }

    samples.append({
        "system": "你是BI分析助手。",
        "instruction": "请解读以下BI日报,输出JSON结论(summary/highlights/risks/actions)。",
        "input": json.dumps(payload, ensure_ascii=False),
        "output": json.dumps(output_obj, ensure_ascii=False),
    })

Path(DATA_PATH).write_text(json.dumps(samples, ensure_ascii=False, indent=2), encoding="utf-8")
print("生成完成:", DATA_PATH, "样本数:", len(samples))

# --- 5. swanlab ---
# Best-effort swanlab setup: tracking is optional, so any failure
# (missing package, no account/login) just leaves SWANLAB_ON = False.
SWANLAB_ON = False
try:
    import swanlab

    swanlab.init(
        project="bi-router",
        experiment_name="qwen35-08b-2b-router",
        config={
            "small_model": "Qwen/Qwen3.5-0.8B",
            "large_model": "Qwen/Qwen3.5-2B",
            "route_policy": "complexity + quality fallback",
        },
    )
    SWANLAB_ON = True
    print("swanlab enabled")
except Exception as e:
    print("swanlab disabled:", e)

# --- 6. Prepare test data ---
import json
import random
from datetime import datetime, timedelta
from pathlib import Path

# Prefer the real SFT file on Drive; otherwise fall back to a local eval file.
DATA_PATH = "/content/drive/MyDrive/bi_sft_300.json"
if not Path(DATA_PATH).exists():
    DATA_PATH = "/content/bi_router_eval.json"

# Generate a small synthetic eval set only when no data file exists yet.
if not Path(DATA_PATH).exists():
    random.seed(42)
    rows = []
    base_date = datetime(2026, 3, 6)
    for i in range(50):
        day = (base_date - timedelta(days=i)).strftime("%Y-%m-%d")
        payload = {
            "date": day,
            "metrics": [
                {"name": "GMV", "value": round(random.uniform(800, 3000), 2), "mom": round(random.uniform(-0.3, 0.2), 3)},
                {"name": "订单数", "value": random.randint(3000, 15000), "mom": round(random.uniform(-0.2, 0.2), 3)},
                {"name": "转化率", "value": round(random.uniform(0.01, 0.05), 4), "mom": round(random.uniform(-0.02, 0.01), 4)},
                {"name": "退款率", "value": round(random.uniform(0.02, 0.15), 4), "mom": round(random.uniform(-0.01, 0.03), 4)},
                {"name": "ROAS", "value": round(random.uniform(1.2, 4.5), 3), "mom": round(random.uniform(-0.4, 0.2), 3)},
            ],
            "dimensions": {
                "channel": [
                    {"name": "信息流", "gmv_mom": round(random.uniform(-0.4, 0.2), 3)},
                    {"name": "搜索", "gmv_mom": round(random.uniform(-0.3, 0.3), 3)},
                    {"name": "站内", "gmv_mom": round(random.uniform(-0.2, 0.2), 3)},
                ],
                "region": [
                    {"name": "华东", "gmv_mom": round(random.uniform(-0.4, 0.2), 3)},
                    {"name": "华南", "gmv_mom": round(random.uniform(-0.2, 0.3), 3)},
                    {"name": "华北", "gmv_mom": round(random.uniform(-0.3, 0.2), 3)},
                ],
            },
        }
        rows.append({"input": json.dumps(payload, ensure_ascii=False)})

    # Inside the `if`: `rows` only exists on the generation path, so the
    # write must not run when an existing data file was found above.
    Path(DATA_PATH).write_text(json.dumps(rows, ensure_ascii=False, indent=2), encoding="utf-8")

print("DATA_PATH:", DATA_PATH)

# --- 7. Lazy model loading ---
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_IDS = {
    "small": "Qwen/Qwen3.5-0.8B",
    "large": "Qwen/Qwen3.5-2B",
}

# Lazy caches: a model is only downloaded/loaded the first time it is routed to,
# which keeps T4 memory free when the 2B fallback is never needed.
tokenizers = {}
models = {}


def get_model(name: str):
    """Return the cached (tokenizer, model) pair for "small" or "large", loading it on first use.

    Loads in float16 with device_map="auto" to fit the Colab T4.
    Raises KeyError for names outside MODEL_IDS.
    """
    if name in models:
        return tokenizers[name], models[name]

    model_id = MODEL_IDS[name]
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        device_map="auto",
        dtype=torch.float16,
    )
    # Some causal LMs ship without a pad token; fall back to EOS so generate() works.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    tokenizers[name] = tokenizer
    models[name] = model
    return tokenizer, model


print("lazy loader ready")

# --- 8. (Optional) Attach LoRA adapters ---
# 如果你没有微调过,就保持注释
# from peft import PeftModel
# small_adapter_path = "/content/drive/MyDrive/qwen35_08b_bi_lora"
# large_adapter_path = "/content/drive/MyDrive/qwen35_2b_bi_lora"
#
# tok_s, mdl_s = get_model("small")
# models["small"] = PeftModel.from_pretrained(mdl_s, small_adapter_path)
#
# tok_l, mdl_l = get_model("large")
# models["large"] = PeftModel.from_pretrained(mdl_l, large_adapter_path)

9. 路由规则与质量评分
import json, re

# Keys the model's JSON answer must contain to be considered well-formed.
REQUIRED_KEYS = ["summary", "highlights", "risks", "actions"]


def safe_json_load(text: str):
    """Best-effort parse of LLM output into a Python object.

    Strips markdown code fences, tries a direct json.loads, then falls
    back to the first {...} span in the text. Returns None when nothing
    parseable is found.
    """
    if not text or not text.strip():
        return None
    t = text.strip()
    t = re.sub(r"^```json\s*", "", t, flags=re.I)
    t = re.sub(r"^```", "", t)
    t = re.sub(r"```$", "", t).strip()

    # Try a direct parse first.
    try:
        return json.loads(t)
    except Exception:
        pass

    # Fall back to the first {...} block embedded in surrounding text.
    m = re.search(r"\{[\s\S]*\}", t)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except Exception:
        return None


def output_quality_score(obj):
    """Score a parsed answer in [0, 1]: 0.2 per required key present, +0.2 for a summary of >= 20 chars."""
    if not isinstance(obj, dict):
        return 0.0
    score = 0.0
    for k in REQUIRED_KEYS:
        if k in obj:
            score += 0.2
    if isinstance(obj.get("summary"), str) and len(obj["summary"]) >= 20:
        score += 0.2
    return min(score, 1.0)


def should_direct_large(payload: dict, threshold: float = 0.95) -> bool:
    """Complexity rule: send very busy inputs straight to the 2B model.

    NOTE(review): route_predict() calls this function but the original
    notebook never defined it — this supplies the missing rule. The score
    in [0, 1] grows with metric count (weight 0.4, saturating at 12),
    dimension-entry count (weight 0.3, saturating at 20) and the share of
    metrics with a large swing (|mom| >= 0.15, weight 0.3). With the
    default threshold of 0.95 only extremely complex payloads skip the
    0.8B model, matching the "特别复杂才直走2B" intent.
    """
    if not isinstance(payload, dict):
        return False
    metrics = payload.get("metrics") or []
    dims = payload.get("dimensions") or {}
    dim_entries = sum(len(v) for v in dims.values())
    swings = sum(1 for m in metrics if abs(m.get("mom") or 0) >= 0.15)
    swing_ratio = swings / len(metrics) if metrics else 0.0
    score = (
        0.4 * min(len(metrics) / 12, 1.0)
        + 0.3 * min(dim_entries / 20, 1.0)
        + 0.3 * swing_ratio
    )
    return score >= threshold

# --- 10. Inference function (single model) ---
def generate_bi_json(payload: dict, model_name: str, max_new_tokens: int = 320):
    """Run one model ("small" or "large") on a BI payload.

    Returns a dict with the model name, raw decoded text, the parsed JSON
    object (or None) and its quality score in [0, 1].
    """
    tokenizer, model = get_model(model_name)

    system = (
        "你是BI分析助手。只输出一个JSON对象,不要任何解释。"
        "JSON键严格为:summary, highlights, risks, actions。"
        "highlights/risks/actions 每个数组最多3条,每条不超过40字。"
    )
    user = "请解读以下BI日报并按指定JSON输出:" + json.dumps(payload, ensure_ascii=False)

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]

    # enable_thinking=False keeps Qwen's chat template from emitting a
    # reasoning block, so the reply is the bare JSON we asked for.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
        enable_thinking=False,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tail, not the echoed prompt.
    text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
    obj = safe_json_load(text)
    quality = output_quality_score(obj)

    return {"model": model_name, "raw_text": text, "parsed": obj, "quality": quality}

# --- 11. Routing (0.8B primary + 2B fallback) ---
def route_predict(payload: dict, quality_threshold: float = 0.65, direct_large_threshold: float = 0.95):
    """Two-tier routing: 0.8B by default, 2B on rule hit or low-quality output.

    Returns the chosen model's result dict, augmented with a
    "route_reason" string (and "small_preview" when the 2B fallback ran).
    """
    # Only inputs whose complexity rule fires go straight to the 2B model.
    if should_direct_large(payload, threshold=direct_large_threshold):
        large = generate_bi_json(payload, "large")
        large["route_reason"] = "complex_input_direct_large"
        return large

    # Default path: try the cheap 0.8B model and accept if quality passes.
    small = generate_bi_json(payload, "small")
    if small["quality"] >= quality_threshold:
        small["route_reason"] = "small_pass"
        return small

    # Fallback: rerun on 2B, keeping the small model's attempt for debugging.
    large = generate_bi_json(payload, "large")
    large["route_reason"] = f"fallback_from_small_quality_{small['quality']:.2f}"
    large["small_preview"] = small
    return large

# --- 12. Single-sample test ---
# One hand-crafted "bad day" report (GMV down, refunds up, low ROAS) to
# exercise the router end to end.
test_input = {
    "date": "2026-03-04",
    "metrics": [
        {"name": "GMV", "value": 980, "mom": -0.18, "yoy": -0.05},
        {"name": "订单数", "value": 7200, "mom": -0.09, "yoy": 0.03},
        {"name": "转化率", "value": 0.018, "mom": -0.006, "yoy": -0.002},
        {"name": "退款率", "value": 0.11, "mom": 0.04, "yoy": 0.02},
        {"name": "ROAS", "value": 1.6, "mom": -0.25, "yoy": -0.18},
    ],
    "dimensions": {
        "channel": [
            {"name": "信息流", "gmv_mom": -0.32},
            {"name": "搜索", "gmv_mom": 0.06},
            {"name": "站内", "gmv_mom": -0.08},
        ],
        "region": [
            {"name": "华东", "gmv_mom": -0.21},
            {"name": "华南", "gmv_mom": 0.04},
            {"name": "华北", "gmv_mom": -0.12},
        ],
    },
}

result = route_predict(test_input)
print("model:", result["model"])
print("reason:", result["route_reason"])
print("quality:", result["quality"])
print("raw_text:", result["raw_text"])
print("parsed:", result["parsed"])

# --- 13. Batch routing + swanlab logging ---
import pandas as pd
from pathlib import Path

# Route the first 20 eval rows and collect per-row routing metadata.
rows = json.loads(Path(DATA_PATH).read_text(encoding="utf-8"))
outputs = []
for row in rows[:20]:
    # Rows may store the payload as a JSON string under "input" or be the payload itself.
    payload = json.loads(row["input"]) if isinstance(row.get("input"), str) else row
    pred = route_predict(payload)
    outputs.append({
        "model": pred["model"],
        "route_reason": pred["route_reason"],
        "quality": pred["quality"],
        "parsed_ok": isinstance(pred["parsed"], dict),
    })

df = pd.DataFrame(outputs)
print(df.head())
print("\nmodel count:")
print(df["model"].value_counts())
print("\nreason count:")
print(df["route_reason"].value_counts())

# Push aggregate routing stats to swanlab when tracking is enabled.
if SWANLAB_ON:
    model_counts = df["model"].value_counts().to_dict()
    reason_counts = df["route_reason"].value_counts().to_dict()
    swanlab.log({
        "avg_quality": float(df["quality"].mean()),
        "parsed_ok_rate": float(df["parsed_ok"].mean()),
        "small_count": int(model_counts.get("small", 0)),
        "large_count": int(model_counts.get("large", 0)),
        "route_reason_counts": reason_counts,
    })
    print("swanlab logged")

# --- 14. Save results ---
from pathlib import Path
import json

# Persist routing results as JSONL: one routed record per line.
save_path = Path("/content/bi_router_results.jsonl")
with save_path.open("w", encoding="utf-8") as f:
    for x in outputs:
        f.write(json.dumps(x, ensure_ascii=False) + "\n")

print("saved:", save_path)

# Thanks for reading!