|
22 | 22 |
|
23 | 23 | apis = ['sdapi', 'openai', 'openai-module'] |
24 | 24 |
|
| 25 | +def extract_lora_tags(prompt): |
| 26 | + |
| 27 | + pattern = r'<lora:([^:>]+):([^>]+)>' |
| 28 | + lora_data = [] |
| 29 | + |
| 30 | + matches = list(re.finditer(pattern, prompt)) |
| 31 | + |
| 32 | + for match in matches: |
| 33 | + raw_path = match.group(1) |
| 34 | + raw_mul = match.group(2) |
| 35 | + try: |
| 36 | + mul = float(raw_mul) |
| 37 | + except ValueError: |
| 38 | + continue |
| 39 | + |
| 40 | + is_high_noise = False |
| 41 | + prefix = "|high_noise|" |
| 42 | + if raw_path.startswith(prefix): |
| 43 | + raw_path = raw_path[len(prefix):] |
| 44 | + is_high_noise = True |
| 45 | + |
| 46 | + lora_data.append({ |
| 47 | + 'name': raw_path, |
| 48 | + 'multiplier': mul, |
| 49 | + 'is_high_noise': is_high_noise, |
| 50 | + }) |
| 51 | + |
| 52 | + prompt = prompt.replace(match.group(0), "", 1) |
| 53 | + |
| 54 | + return prompt, lora_data |
| 55 | + |
25 | 56 |
|
26 | 57 | def parse_arguments(): |
27 | 58 | ap = argparse.ArgumentParser( |
@@ -137,8 +168,13 @@ def parse_arguments(): |
137 | 168 | if not args_dict.get("server_url", "").strip(): |
138 | 169 | ap.error("--server-url not provided and SD_SERVER_URL env var not found") |
139 | 170 |
|
140 | | - if not args_dict.get("prompt", "").strip(): |
| 171 | + prompt = args_dict.get("prompt", "").strip() |
| 172 | + prompt, lora = extract_lora_tags(prompt) |
| 173 | + if not prompt: |
141 | 174 | ap.error("argument -p/--prompt must be non‑empty") |
| 175 | + args_dict["prompt"] = prompt |
| 176 | + if lora: |
| 177 | + args_dict["lora"] = lora |
142 | 178 |
|
143 | 179 | util_keys = {'verbose', 'server_url', 'output', 'output_begin_idx', 'api', 'init_img', 'mask', 'ref_image', 'output_format'} |
144 | 180 |
|
@@ -352,7 +388,10 @@ def truncate_for_debug(obj, max_length=512): |
352 | 388 | def do_request(url, data, headers): |
353 | 389 |
|
354 | 390 | if HAS_REQUESTS: |
355 | | - response = requests.post(url, headers=headers, data=data, timeout=30) |
| 391 | + response = requests.post(url, headers=headers, data=data, timeout=600) |
| 392 | + if response.status_code != 200: |
| 393 | + print(f"HTTP {response.status_code}: {response.reason}") |
| 394 | + print(f" {response.text}") |
356 | 395 | response.raise_for_status() |
357 | 396 | return response.text |
358 | 397 |
|
@@ -396,6 +435,27 @@ def main_sdapi(util_opts, gen_opts, verbose=False): |
396 | 435 | endpoint = urllib.parse.urljoin(server_url, "sdapi/v1/txt2img") |
397 | 436 |
|
398 | 437 | api_parameters = build_sdapi_parameters(gen_opts, util_opts, image_opts) |
| 438 | + lora = gen_opts.get('lora') |
| 439 | + if lora: |
| 440 | + # TODO: refactor, error handling, is_high_noise |
| 441 | + lora_list = json.loads(requests.get(urllib.parse.urljoin(server_url, "sdapi/v1/loras")).text) |
| 442 | + #print(f"remote lora list: {json.dumps(lora_list, indent=2)}") |
| 443 | + lora_map = {l.get("name"): l.get("path") for l in lora_list} |
| 444 | + req_lora = [] |
| 445 | + print(f"requesting LoRAs") |
| 446 | + for llora in lora: |
| 447 | + name = llora['name'] |
| 448 | + rlora = lora_map.get(name) |
| 449 | + if rlora: |
| 450 | + entry = {'path': rlora, 'multiplier': llora["multiplier"]} |
| 451 | + if verbose: |
| 452 | + print(f" lora '{name}' mapped to '{rlora}'") |
| 453 | + req_lora.append(entry) |
| 454 | + else: |
| 455 | + if verbose: |
| 456 | + print(f" warning: lora '{name}' not found on remote") |
| 457 | + if req_lora: |
| 458 | + api_parameters["lora"] = req_lora |
399 | 459 |
|
400 | 460 | if verbose: |
401 | 461 | print(f"Using sdapi") |
|
0 commit comments