support fro multiple litellm auth methods
This commit is contained in:
@@ -146,6 +146,7 @@ package.json
|
||||
|
||||
# Development docs
|
||||
connpy_roadmap.md
|
||||
testnew/
|
||||
testall/
|
||||
testremote/
|
||||
*.db
|
||||
@@ -170,3 +171,6 @@ MULTI_USER_IMPLEMENTATION_STEPS.md
|
||||
#themes
|
||||
nord.yml
|
||||
theme.py
|
||||
|
||||
#ai auth
|
||||
auth.json
|
||||
|
||||
+49
-16
@@ -108,7 +108,7 @@ class ai:
|
||||
r'^systemctl\s+status\s+', r'^journalctl\s+'
|
||||
]
|
||||
|
||||
def __init__(self, config, org=None, api_key=None, engineer_model=None, architect_model=None, engineer_api_key=None, architect_api_key=None, console=None, confirm_handler=None, trust=False):
|
||||
def __init__(self, config, org=None, api_key=None, engineer_model=None, architect_model=None, engineer_api_key=None, architect_api_key=None, console=None, confirm_handler=None, trust=False, engineer_auth=None, architect_auth=None, **kwargs):
|
||||
self.config = config
|
||||
self.console = console or printer.console
|
||||
self.confirm_handler = confirm_handler or self._local_confirm_handler
|
||||
@@ -127,6 +127,29 @@ class ai:
|
||||
self.engineer_key = engineer_api_key or aiconfig.get("engineer_api_key")
|
||||
self.architect_key = architect_api_key or aiconfig.get("architect_api_key")
|
||||
|
||||
# Auth configurations (Prioridad: Argumento -> Config)
|
||||
self.engineer_auth = engineer_auth if engineer_auth is not None else aiconfig.get("engineer_auth")
|
||||
if self.engineer_auth is None:
|
||||
self.engineer_auth = {}
|
||||
elif not isinstance(self.engineer_auth, dict):
|
||||
self.engineer_auth = {}
|
||||
|
||||
self.architect_auth = architect_auth if architect_auth is not None else aiconfig.get("architect_auth")
|
||||
if self.architect_auth is None:
|
||||
self.architect_auth = {}
|
||||
elif not isinstance(self.architect_auth, dict):
|
||||
self.architect_auth = {}
|
||||
|
||||
# Backward compatibility fallbacks: only inject api_key if the auth dict is empty/not configured
|
||||
if self.engineer_key and not self.engineer_auth:
|
||||
self.engineer_auth["api_key"] = self.engineer_key
|
||||
if self.architect_key and not self.architect_auth:
|
||||
self.architect_auth["api_key"] = self.architect_key
|
||||
|
||||
# Strategic Reasoning Engine (Architect) availability
|
||||
is_architect_keyless = "vertex" in self.architect_model.lower() or "ollama" in self.architect_model.lower() or "local" in self.architect_model.lower()
|
||||
self.has_architect = bool(self.architect_key or self.architect_auth or is_architect_keyless)
|
||||
|
||||
# Custom Trusted Commands Regexes
|
||||
custom_trusted = aiconfig.get("trusted_commands", [])
|
||||
if isinstance(custom_trusted, str):
|
||||
@@ -172,7 +195,7 @@ class ai:
|
||||
|
||||
# Prompts base agnósticos
|
||||
architect_instructions = ""
|
||||
if self.architect_key:
|
||||
if self.has_architect:
|
||||
architect_instructions = """
|
||||
CRITICAL - CONSULT vs ESCALATE:
|
||||
- ALWAYS use 'consult_architect' for: Configuration planning, design decisions, complex troubleshooting.
|
||||
@@ -188,7 +211,7 @@ class ai:
|
||||
else:
|
||||
architect_instructions = """
|
||||
CRITICAL - ARCHITECT UNAVAILABLE:
|
||||
- The Strategic Reasoning Engine (Architect) is currently UNAVAILABLE because its API key is not configured.
|
||||
- The Strategic Reasoning Engine (Architect) is currently UNAVAILABLE because its API key or authentication is not configured.
|
||||
- DO NOT attempt to consult or escalate to the architect.
|
||||
- If the user asks to consult the architect, inform them that the Architect is offline and offer to help them directly to the best of your abilities.
|
||||
"""
|
||||
@@ -294,15 +317,19 @@ class ai:
|
||||
if status_formatter:
|
||||
self.tool_status_formatters[name] = status_formatter
|
||||
|
||||
def _stream_completion(self, model, messages, tools, api_key, status=None, label="", debug=False, chunk_callback=None, **kwargs):
|
||||
def _stream_completion(self, model, messages, tools, api_key=None, status=None, label="", debug=False, chunk_callback=None, auth=None, **kwargs):
|
||||
"""Stream a completion call, rendering styled Markdown in real-time.
|
||||
|
||||
Returns (response, streamed) where:
|
||||
- response: reconstructed ModelResponse (same as non-streaming)
|
||||
- streamed: True if text was rendered to console during streaming
|
||||
"""
|
||||
auth_dict = auth if auth is not None else {}
|
||||
if api_key and "api_key" not in auth_dict:
|
||||
auth_dict = auth_dict.copy()
|
||||
auth_dict["api_key"] = api_key
|
||||
|
||||
stream_resp = completion(model=model, messages=messages, tools=tools, api_key=api_key, stream=True, **kwargs)
|
||||
stream_resp = completion(model=model, messages=messages, tools=tools, stream=True, **auth_dict, **kwargs)
|
||||
|
||||
chunks = []
|
||||
full_content = ""
|
||||
@@ -745,7 +772,7 @@ class ai:
|
||||
|
||||
try:
|
||||
safe_messages = self._sanitize_messages(messages)
|
||||
response = completion(model=self.engineer_model, messages=safe_messages, tools=tools, api_key=self.engineer_key)
|
||||
response = completion(model=self.engineer_model, messages=safe_messages, tools=tools, **self.engineer_auth)
|
||||
except Exception as e:
|
||||
if status: status.stop()
|
||||
raise ValueError(f"Engineer failed to connect: {str(e)}")
|
||||
@@ -981,8 +1008,9 @@ class ai:
|
||||
|
||||
@MethodHook
|
||||
def ask(self, user_input, dryrun=False, chat_history=None, status=None, debug=False, stream=True, session_id=None, chunk_callback=None):
|
||||
if not self.engineer_key:
|
||||
raise ValueError("Engineer API key not configured. Use 'connpy config --engineer-api-key <key>' to set it.")
|
||||
is_engineer_keyless = "vertex" in self.engineer_model.lower() or "ollama" in self.engineer_model.lower() or "local" in self.engineer_model.lower()
|
||||
if not self.engineer_key and not self.engineer_auth and not is_engineer_keyless:
|
||||
raise ValueError("Engineer API key or authentication not configured. Use 'connpy config --engineer-auth <auth>' to set it.")
|
||||
|
||||
if chat_history is None: chat_history = []
|
||||
|
||||
@@ -1031,6 +1059,7 @@ class ai:
|
||||
tools = self._get_architect_tools() if current_brain == "architect" else self._get_engineer_tools()
|
||||
model = self.architect_model if current_brain == "architect" else self.engineer_model
|
||||
key = self.architect_key if current_brain == "architect" else self.engineer_key
|
||||
current_auth = self.architect_auth if current_brain == "architect" else self.engineer_auth
|
||||
|
||||
# Estructura optimizada para Prompt Caching (Solo para Anthropic directo, Vertex tiene reglas distintas)
|
||||
if "claude" in model.lower() and "vertex" not in model.lower():
|
||||
@@ -1090,12 +1119,12 @@ class ai:
|
||||
safe_messages = self._sanitize_messages(messages)
|
||||
if stream:
|
||||
response, streamed_response = self._stream_completion(
|
||||
model=model, messages=safe_messages, tools=tools, api_key=key,
|
||||
model=model, messages=safe_messages, tools=tools, auth=current_auth,
|
||||
status=status, label=label, debug=debug, num_retries=3,
|
||||
chunk_callback=chunk_callback
|
||||
)
|
||||
else:
|
||||
response = completion(model=model, messages=safe_messages, tools=tools, api_key=key, num_retries=3)
|
||||
response = completion(model=model, messages=safe_messages, tools=tools, num_retries=3, **current_auth)
|
||||
except Exception as e:
|
||||
if current_brain == "architect":
|
||||
if status: status.update("[unavailable]Architect unavailable! Falling back to Engineer...")
|
||||
@@ -1104,6 +1133,7 @@ class ai:
|
||||
model = self.engineer_model
|
||||
tools = self._get_engineer_tools()
|
||||
key = self.engineer_key
|
||||
current_auth = self.engineer_auth
|
||||
# Rebuild messages with Engineer system prompt and original user request
|
||||
messages = [{"role": "system", "content": self.engineer_system_prompt}]
|
||||
# Add chat history if exists (excluding system prompt)
|
||||
@@ -1196,6 +1226,7 @@ class ai:
|
||||
model = self.architect_model
|
||||
tools = self._get_architect_tools()
|
||||
key = self.architect_key
|
||||
current_auth = self.architect_auth
|
||||
messages[0] = {"role": "system", "content": self.architect_system_prompt}
|
||||
# Prepare handover context to inject AFTER all tool responses
|
||||
handover_msg = f"HANDOVER FROM EXECUTION ENGINE\n\nReason: {args['reason']}\n\nContext: {args['context']}\n\nYou are now in control of this conversation."
|
||||
@@ -1217,6 +1248,7 @@ class ai:
|
||||
model = self.engineer_model
|
||||
tools = self._get_engineer_tools()
|
||||
key = self.engineer_key
|
||||
current_auth = self.engineer_auth
|
||||
messages[0] = {"role": "system", "content": self.engineer_system_prompt}
|
||||
# Prepare handover context to inject AFTER all tool responses
|
||||
handover_msg = f"HANDOVER FROM ARCHITECT\n\nSummary: {args['summary']}\n\nYou are now back in control. Continue handling the user's requests."
|
||||
@@ -1258,7 +1290,7 @@ class ai:
|
||||
messages.append({"role": "user", "content": "Hard iteration limit reached. Please provide a summary of your findings so far."})
|
||||
try:
|
||||
safe_messages = self._sanitize_messages(messages)
|
||||
response = completion(model=model, messages=safe_messages, tools=[], api_key=key)
|
||||
response = completion(model=model, messages=safe_messages, tools=[], **current_auth)
|
||||
resp_msg = response.choices[0].message
|
||||
messages.append(resp_msg.model_dump(exclude_none=True))
|
||||
except Exception as e:
|
||||
@@ -1278,7 +1310,7 @@ class ai:
|
||||
try:
|
||||
safe_messages = self._sanitize_messages(summary_messages)
|
||||
# Use tools=None to force a text summary during interruption
|
||||
response = completion(model=model, messages=safe_messages, tools=None, api_key=key)
|
||||
response = completion(model=model, messages=safe_messages, tools=None, **current_auth)
|
||||
resp_msg = response.choices[0].message
|
||||
messages.append(resp_msg.model_dump(exclude_none=True))
|
||||
|
||||
@@ -1415,6 +1447,7 @@ Node: {node_name}"""
|
||||
# Use models based on persona
|
||||
current_model = self.architect_model if persona == "architect" else self.engineer_model
|
||||
current_key = self.architect_key if persona == "architect" else self.engineer_key
|
||||
current_auth = self.architect_auth if persona == "architect" else self.engineer_auth
|
||||
|
||||
try:
|
||||
while iteration < max_iterations:
|
||||
@@ -1424,8 +1457,8 @@ Node: {node_name}"""
|
||||
model=current_model,
|
||||
messages=messages,
|
||||
tools=mcp_tools if mcp_tools else None,
|
||||
api_key=current_key,
|
||||
stream=True
|
||||
stream=True,
|
||||
**current_auth
|
||||
)
|
||||
|
||||
full_content = ""
|
||||
@@ -1498,8 +1531,8 @@ Node: {node_name}"""
|
||||
model=self.engineer_model,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
api_key=self.engineer_key,
|
||||
stream=True
|
||||
stream=True,
|
||||
**self.engineer_auth
|
||||
)
|
||||
|
||||
full_content = ""
|
||||
|
||||
@@ -47,7 +47,7 @@ class AIHandler:
|
||||
# Determinar session_id para retomar
|
||||
session_id = None
|
||||
if args.resume:
|
||||
sessions = self.app.services.ai.list_sessions()
|
||||
sessions, _ = self.app.services.ai.list_sessions()
|
||||
session_id = sessions[0]["id"] if sessions else None
|
||||
if not session_id:
|
||||
printer.warning("No previous session found to resume.")
|
||||
@@ -66,15 +66,22 @@ class AIHandler:
|
||||
elif settings.get(key):
|
||||
arguments[key] = settings.get(key)
|
||||
|
||||
for key in ["engineer_auth", "architect_auth"]:
|
||||
cli_val = getattr(args, key, None)
|
||||
if cli_val:
|
||||
arguments[key] = self._parse_auth_value(cli_val[0])
|
||||
elif settings.get(key):
|
||||
arguments[key] = settings.get(key)
|
||||
|
||||
# Check keys only if running in local mode (not remote)
|
||||
if getattr(self.app.services, "mode", "local") == "local":
|
||||
if not arguments.get("engineer_api_key"):
|
||||
printer.error("Engineer API key not configured. The chat cannot start.")
|
||||
printer.info("Use 'connpy config --engineer-api-key <key>' to set it.")
|
||||
if not arguments.get("engineer_api_key") and not arguments.get("engineer_auth"):
|
||||
printer.error("Engineer API key/auth not configured. The chat cannot start.")
|
||||
printer.info("Use 'connpy config --engineer-api-key <key>' or 'connpy config --engineer-auth <auth>' to set it.")
|
||||
sys.exit(1)
|
||||
if not arguments.get("architect_api_key"):
|
||||
printer.warning("Architect API key not configured. Architect will be unavailable.")
|
||||
printer.info("Use 'connpy config --architect-api-key <key>' to enable it.")
|
||||
if not arguments.get("architect_api_key") and not arguments.get("architect_auth"):
|
||||
printer.warning("Architect API key/auth not configured. Architect will be unavailable.")
|
||||
printer.info("Use 'connpy config --architect-api-key <key>' or 'connpy config --architect-auth <auth>' to enable it.")
|
||||
|
||||
# El resto de la interacción el CLI la maneja con el agente subyacente
|
||||
self.app.myai = self.app.services.ai
|
||||
@@ -256,3 +263,33 @@ class AIHandler:
|
||||
|
||||
except Exception as e:
|
||||
printer.error(str(e))
|
||||
|
||||
def _parse_auth_value(self, value):
|
||||
if not value or value.lower() in ["none", "clear"]:
|
||||
return None
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
if os.path.exists(value):
|
||||
try:
|
||||
with open(value, "r") as f:
|
||||
content = f.read()
|
||||
try:
|
||||
return json.loads(content)
|
||||
except ValueError:
|
||||
return yaml.safe_load(content)
|
||||
except Exception as e:
|
||||
printer.error(f"Failed to read/parse auth file '{value}': {e}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
return json.loads(value)
|
||||
except ValueError:
|
||||
try:
|
||||
parsed = yaml.safe_load(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
raise ValueError()
|
||||
except Exception:
|
||||
printer.error("Auth parameter must be a valid JSON/YAML string, or a path to a JSON/YAML file.")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -19,8 +19,10 @@ class ConfigHandler:
|
||||
"theme": self.set_theme,
|
||||
"engineer_model": self.set_ai_config,
|
||||
"engineer_api_key": self.set_ai_config,
|
||||
"engineer_auth": self.set_ai_config,
|
||||
"architect_model": self.set_ai_config,
|
||||
"architect_api_key": self.set_ai_config,
|
||||
"architect_auth": self.set_ai_config,
|
||||
"trusted_commands": self.set_ai_config,
|
||||
"service_mode": self.set_service_mode,
|
||||
"remote_host": self.set_remote_host,
|
||||
@@ -127,9 +129,57 @@ class ConfigHandler:
|
||||
try:
|
||||
settings = self.app.services.config_svc.get_settings()
|
||||
aiconfig = settings.get("ai", {})
|
||||
aiconfig[args.command] = args.data[0]
|
||||
val = args.data[0]
|
||||
|
||||
# Check for unset/clear request
|
||||
if val.lower() in ["none", "clear", ""]:
|
||||
if args.command in aiconfig:
|
||||
del aiconfig[args.command]
|
||||
else:
|
||||
# If configuring auth, parse as dictionary (JSON/YAML or file path)
|
||||
if args.command in ["engineer_auth", "architect_auth"]:
|
||||
parsed_val = self._parse_auth_value(val)
|
||||
if parsed_val is not None:
|
||||
aiconfig[args.command] = parsed_val
|
||||
else:
|
||||
if args.command in aiconfig:
|
||||
del aiconfig[args.command]
|
||||
else:
|
||||
aiconfig[args.command] = val
|
||||
|
||||
self.app.services.config_svc.update_setting("ai", aiconfig)
|
||||
printer.success("Config saved")
|
||||
except ConnpyError as e:
|
||||
except (ConnpyError, InvalidConfigurationError) as e:
|
||||
printer.error(str(e))
|
||||
|
||||
def _parse_auth_value(self, value):
|
||||
if value.lower() in ["none", "clear", ""]:
|
||||
return None
|
||||
|
||||
# Check if it's a file path
|
||||
import os
|
||||
if os.path.exists(value):
|
||||
try:
|
||||
with open(value, "r") as f:
|
||||
content = f.read()
|
||||
import json
|
||||
try:
|
||||
return json.loads(content)
|
||||
except ValueError:
|
||||
return yaml.safe_load(content)
|
||||
except Exception as e:
|
||||
raise InvalidConfigurationError(f"Failed to read/parse auth file '{value}': {e}")
|
||||
|
||||
# Try parsing as inline JSON/YAML
|
||||
try:
|
||||
import json
|
||||
return json.loads(value)
|
||||
except ValueError:
|
||||
try:
|
||||
parsed = yaml.safe_load(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
raise ValueError()
|
||||
except Exception:
|
||||
raise InvalidConfigurationError("Auth parameter must be a valid JSON/YAML string, or a path to a JSON/YAML file.")
|
||||
|
||||
|
||||
+18
-16
@@ -181,11 +181,28 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
|
||||
ai_dict = {"__exclude_used__": True, "--help": None, "-h": None}
|
||||
for opt in ["--engineer-model", "--engineer-api-key", "--architect-model", "--architect-api-key"]:
|
||||
ai_dict[opt] = {"*": ai_dict} # takes value, loops back
|
||||
ai_dict["--engineer-auth"] = {"__extra__": lambda w: get_cwd(w, "--engineer-auth"), "*": ai_dict}
|
||||
ai_dict["--architect-auth"] = {"__extra__": lambda w: get_cwd(w, "--architect-auth"), "*": ai_dict}
|
||||
for opt in ["--debug", "--trust", "--list", "--list-sessions", "--session", "--resume", "--delete", "--delete-session", "-y"]:
|
||||
ai_dict[opt] = ai_dict # takes no value, loops back
|
||||
ai_dict["--mcp"] = mcp_dict
|
||||
ai_dict["*"] = ai_dict
|
||||
|
||||
config_dict = {
|
||||
"--allow-uppercase": ["true", "false"],
|
||||
"--fzf": ["true", "false"],
|
||||
"--completion": ["bash", "zsh"],
|
||||
"--fzf-wrapper": ["bash", "zsh"],
|
||||
"--service-mode": ["local", "remote"],
|
||||
"--sync-remote": ["true", "false"],
|
||||
"--help": None, "-h": None,
|
||||
}
|
||||
for opt in ["--keepalive", "--engineer-model", "--engineer-api-key", "--architect-model", "--architect-api-key", "--theme", "--remote", "--trusted-commands"]:
|
||||
config_dict[opt] = {"*": config_dict}
|
||||
config_dict["--configfolder"] = {"__extra__": lambda w: get_cwd(w, "--configfolder", True), "*": config_dict}
|
||||
config_dict["--engineer-auth"] = {"__extra__": lambda w: get_cwd(w, "--engineer-auth"), "*": config_dict}
|
||||
config_dict["--architect-auth"] = {"__extra__": lambda w: get_cwd(w, "--architect-auth"), "*": config_dict}
|
||||
|
||||
mv_state = {"__extra__": _nodes, "--help": None, "-h": None}
|
||||
cp_state = {"__extra__": _nodes, "--help": None, "-h": None}
|
||||
ls_state = {
|
||||
@@ -280,22 +297,7 @@ def _build_tree(nodes, folders, profiles, plugins, configdir):
|
||||
"--list": None, "--help": None,
|
||||
"-h": None,
|
||||
},
|
||||
"config": {
|
||||
"--allow-uppercase": ["true", "false"],
|
||||
"--fzf": ["true", "false"],
|
||||
"--keepalive": None,
|
||||
"--completion": ["bash", "zsh"],
|
||||
"--fzf-wrapper": ["bash", "zsh"],
|
||||
"--configfolder": lambda w: get_cwd(w, "--configfolder", True),
|
||||
"--engineer-model": None, "--engineer-api-key": None,
|
||||
"--architect-model": None, "--architect-api-key": None,
|
||||
"--theme": None,
|
||||
"--service-mode": ["local", "remote"],
|
||||
"--remote": None,
|
||||
"--sync-remote": ["true", "false"],
|
||||
"--trusted-commands": None,
|
||||
"--help": None, "-h": None,
|
||||
},
|
||||
"config": config_dict,
|
||||
"sync": {
|
||||
"--login": None, "--logout": None,
|
||||
"--status": None, "--list": None,
|
||||
|
||||
@@ -276,8 +276,10 @@ class connapp:
|
||||
aiparser.add_argument("ask", nargs='*', help="Ask connpy AI something")
|
||||
aiparser.add_argument("--engineer-model", nargs=1, help="Override engineer model")
|
||||
aiparser.add_argument("--engineer-api-key", nargs=1, help="Override engineer api key")
|
||||
aiparser.add_argument("--engineer-auth", nargs=1, help="Override engineer auth (inline JSON/YAML or file path)")
|
||||
aiparser.add_argument("--architect-model", nargs=1, help="Override architect model")
|
||||
aiparser.add_argument("--architect-api-key", nargs=1, help="Override architect api key")
|
||||
aiparser.add_argument("--architect-auth", nargs=1, help="Override architect auth (inline JSON/YAML or file path)")
|
||||
aiparser.add_argument("--debug", action="store_true", help="Show AI reasoning and tool calls")
|
||||
aiparser.add_argument("-y", "--trust", action="store_true", help="Trust AI to execute unsafe commands without confirmation")
|
||||
aiparser.add_argument("--list", "--list-sessions", dest="list_sessions", action="store_true", help="List saved AI sessions")
|
||||
@@ -341,11 +343,13 @@ class connapp:
|
||||
configcrud.add_argument("--configfolder", dest="configfolder", nargs=1, action=self._store_type, help="Set the default location for config file", metavar="FOLDER")
|
||||
configcrud.add_argument("--engineer-model", dest="engineer_model", nargs=1, action=self._store_type, help="Set engineer model", metavar="MODEL")
|
||||
configcrud.add_argument("--engineer-api-key", dest="engineer_api_key", nargs=1, action=self._store_type, help="Set engineer api_key", metavar="API_KEY")
|
||||
configcrud.add_argument("--engineer-auth", dest="engineer_auth", nargs=1, action=self._store_type, help="Set engineer auth (inline JSON/YAML or file path)", metavar="AUTH")
|
||||
configcrud.add_argument("--theme", dest="theme", nargs=1, action=self._store_type, help="Set application theme (dark, light, or YAML file path)", metavar="THEME")
|
||||
configcrud.add_argument("--service-mode", dest="service_mode", nargs=1, action=self._store_type, help="Set the backend service mode (local or remote)", choices=["local", "remote"])
|
||||
configcrud.add_argument("--remote", dest="remote_host", nargs=1, action=self._store_type, help="Connect to a remote connpy service via gRPC", metavar="HOST:PORT")
|
||||
configcrud.add_argument("--architect-model", dest="architect_model", nargs=1, action=self._store_type, help="Set architect model", metavar="MODEL")
|
||||
configcrud.add_argument("--architect-api-key", dest="architect_api_key", nargs=1, action=self._store_type, help="Set architect api_key", metavar="API_KEY")
|
||||
configcrud.add_argument("--architect-auth", dest="architect_auth", nargs=1, action=self._store_type, help="Set architect auth (inline JSON/YAML or file path)", metavar="AUTH")
|
||||
configcrud.add_argument("--sync-remote", dest="sync_remote", nargs=1, action=self._store_type, help="Sync remote nodes to Google Drive", choices=["true","false"])
|
||||
configparser.add_argument("--trusted-commands", dest="trusted_commands", nargs=1, action=self._store_type, help="Set custom trusted commands regexes (comma separated)", metavar="REGEX,REGEX")
|
||||
configparser.set_defaults(func=self._config.dispatch)
|
||||
|
||||
+4
-9
@@ -439,21 +439,16 @@ class node:
|
||||
# Remove any stray \x00 bytes and forward normally
|
||||
clean_data = data.replace(b'\x00', b'')
|
||||
if clean_data:
|
||||
# Track command boundaries when user hits Enter
|
||||
if hasattr(self, 'mylog') and (b'\r' in clean_data or b'\n' in clean_data):
|
||||
# Introduce a tiny 20ms delay to allow late-arriving tab-completion bytes
|
||||
# to be written to mylog before finalizing the boundary marker.
|
||||
async def delayed_marker():
|
||||
await asyncio.sleep(0.02)
|
||||
if hasattr(self, 'mylog'):
|
||||
# Track command boundaries when user hits Enter or presses Ctrl+C
|
||||
if hasattr(self, 'mylog') and (b'\r' in clean_data or b'\n' in clean_data or b'\x03' in clean_data):
|
||||
pos = self.mylog.tell()
|
||||
self.cmd_byte_positions.append((pos, None))
|
||||
marker_cmd = "CANCELLED" if b'\x03' in clean_data else None
|
||||
self.cmd_byte_positions.append((pos, marker_cmd))
|
||||
if hasattr(self, 'current_local_stream') and self.current_local_stream is not None:
|
||||
try:
|
||||
await self.current_local_stream.write(f'\x1b]133;B;{pos}\x07'.encode())
|
||||
except Exception:
|
||||
pass
|
||||
asyncio.create_task(delayed_marker())
|
||||
|
||||
try:
|
||||
os.write(child_fd, clean_data)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -3,7 +3,7 @@
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from . import connpy_pb2 as connpy__pb2
|
||||
import connpy_pb2 as connpy__pb2
|
||||
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.80.0'
|
||||
|
||||
@@ -893,6 +893,10 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
overrides = {}
|
||||
if req.engineer_model: overrides["engineer_model"] = req.engineer_model
|
||||
if req.engineer_api_key: overrides["engineer_api_key"] = req.engineer_api_key
|
||||
if req.architect_model: overrides["architect_model"] = req.architect_model
|
||||
if req.architect_api_key: overrides["architect_api_key"] = req.architect_api_key
|
||||
if req.HasField("engineer_auth"): overrides["engineer_auth"] = from_struct(req.engineer_auth)
|
||||
if req.HasField("architect_auth"): overrides["architect_auth"] = from_struct(req.architect_auth)
|
||||
|
||||
# Start AI in its own thread so we can keep listening for interrupts
|
||||
ai_thread = threading.Thread(
|
||||
@@ -967,7 +971,8 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
|
||||
@handle_errors
|
||||
def configure_provider(self, request, context):
|
||||
self.service.configure_provider(request.provider, request.model, request.api_key)
|
||||
auth_dict = from_struct(request.auth) if request.HasField("auth") else None
|
||||
self.service.configure_provider(request.provider, request.model, request.api_key, auth=auth_dict)
|
||||
return Empty()
|
||||
|
||||
@handle_errors
|
||||
|
||||
@@ -745,6 +745,10 @@ class AIStub:
|
||||
)
|
||||
if chat_history is not None:
|
||||
initial_req.chat_history.CopyFrom(to_value(chat_history))
|
||||
if "engineer_auth" in overrides and overrides["engineer_auth"]:
|
||||
initial_req.engineer_auth.CopyFrom(to_struct(overrides["engineer_auth"]))
|
||||
if "architect_auth" in overrides and overrides["architect_auth"]:
|
||||
initial_req.architect_auth.CopyFrom(to_struct(overrides["architect_auth"]))
|
||||
|
||||
req_queue.put(initial_req)
|
||||
|
||||
@@ -926,8 +930,10 @@ class AIStub:
|
||||
self.stub.delete_session(connpy_pb2.StringRequest(value=session_id))
|
||||
|
||||
@handle_errors
|
||||
def configure_provider(self, provider, model=None, api_key=None):
|
||||
def configure_provider(self, provider, model=None, api_key=None, auth=None):
|
||||
req = connpy_pb2.ProviderRequest(provider=provider, model=model or "", api_key=api_key or "")
|
||||
if auth:
|
||||
req.auth.CopyFrom(to_struct(auth))
|
||||
self.stub.configure_provider(req)
|
||||
|
||||
@handle_errors
|
||||
|
||||
@@ -235,6 +235,8 @@ message AskRequest {
|
||||
bool trust = 10;
|
||||
string confirmation_answer = 11;
|
||||
bool interrupt = 12;
|
||||
google.protobuf.Struct engineer_auth = 13;
|
||||
google.protobuf.Struct architect_auth = 14;
|
||||
}
|
||||
|
||||
message AIResponse {
|
||||
@@ -255,6 +257,7 @@ message ProviderRequest {
|
||||
string provider = 1;
|
||||
string model = 2;
|
||||
string api_key = 3;
|
||||
google.protobuf.Struct auth = 4;
|
||||
}
|
||||
|
||||
message IntRequest {
|
||||
|
||||
@@ -58,6 +58,9 @@ class AIService(BaseService):
|
||||
prev_pos = cmd_byte_positions[i-1][0]
|
||||
|
||||
if known_cmd:
|
||||
if known_cmd == "CANCELLED":
|
||||
parsed_positions.append({"pos": pos, "type": "CANCELLED", "preview": ""})
|
||||
else:
|
||||
prev_chunk = raw_bytes[prev_pos:pos]
|
||||
prev_cleaned = self._clean_cisco_scrolling(prev_chunk.decode(errors='replace'))
|
||||
prev_lines = [l for l in prev_cleaned.split('\n') if l.strip()]
|
||||
@@ -129,11 +132,11 @@ class AIService(BaseService):
|
||||
start_pos = item["pos"]
|
||||
preview = item["preview"]
|
||||
|
||||
# Find the end position: next VALID_CMD or EMPTY_PROMPT
|
||||
# Find the end position: next VALID_CMD or EMPTY_PROMPT or CANCELLED
|
||||
end_pos = current_prompt_pos
|
||||
for j in range(i + 1, len(parsed_positions)):
|
||||
next_item = parsed_positions[j]
|
||||
if next_item["type"] in ("VALID_CMD", "EMPTY_PROMPT"):
|
||||
if next_item["type"] in ("VALID_CMD", "EMPTY_PROMPT", "CANCELLED"):
|
||||
end_pos = next_item["pos"]
|
||||
break
|
||||
|
||||
@@ -254,13 +257,15 @@ class AIService(BaseService):
|
||||
else:
|
||||
raise InvalidConfigurationError(f"Session '{session_id}' not found.")
|
||||
|
||||
def configure_provider(self, provider, model=None, api_key=None):
|
||||
def configure_provider(self, provider, model=None, api_key=None, auth=None):
|
||||
"""Update AI provider settings in the configuration."""
|
||||
settings = self.config.config.get("ai", {})
|
||||
if model:
|
||||
settings[f"{provider}_model"] = model
|
||||
if api_key:
|
||||
settings[f"{provider}_api_key"] = api_key
|
||||
if auth is not None:
|
||||
settings[f"{provider}_auth"] = auth
|
||||
|
||||
self.config.config["ai"] = settings
|
||||
self.config._saveconfig(self.config.file)
|
||||
|
||||
+76
-3
@@ -23,7 +23,7 @@ class TestAIInit:
|
||||
myai = ai(config)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
myai.ask("hello")
|
||||
assert "Engineer API key not configured" in str(exc.value)
|
||||
assert "Engineer API key or authentication not configured" in str(exc.value)
|
||||
|
||||
def test_init_missing_architect_key_warns(self, ai_config, capsys, mock_litellm):
|
||||
"""Warns if architect key is missing but doesn't crash."""
|
||||
@@ -58,6 +58,77 @@ class TestAIInit:
|
||||
pass # May fail on other file opens, that's ok
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# AI Auth Dict tests
|
||||
# =========================================================================
|
||||
|
||||
class TestAIAuthDict:
|
||||
def test_init_with_auth_dict(self, ai_config):
|
||||
"""Initializes correctly when auth dicts are configured."""
|
||||
from connpy.ai import ai
|
||||
ai_config.config["ai"]["engineer_api_key"] = None
|
||||
ai_config.config["ai"]["architect_api_key"] = None
|
||||
ai_config.config["ai"]["engineer_auth"] = {"my_key": "my_val"}
|
||||
ai_config.config["ai"]["architect_auth"] = {"another_key": "another_val"}
|
||||
myai = ai(ai_config)
|
||||
assert myai.engineer_auth == {"my_key": "my_val"}
|
||||
assert myai.architect_auth == {"another_key": "another_val"}
|
||||
|
||||
def test_compat_key_injection(self, ai_config):
|
||||
"""Injects API key into auth dict if auth is empty or doesn't have it."""
|
||||
from connpy.ai import ai
|
||||
ai_config.config["ai"]["engineer_api_key"] = "compat-eng-key"
|
||||
ai_config.config["ai"]["architect_api_key"] = "compat-arch-key"
|
||||
ai_config.config["ai"]["engineer_auth"] = {}
|
||||
ai_config.config["ai"]["architect_auth"] = {}
|
||||
myai = ai(ai_config)
|
||||
assert myai.engineer_auth == {"api_key": "compat-eng-key"}
|
||||
assert myai.architect_auth == {"api_key": "compat-arch-key"}
|
||||
|
||||
def test_has_architect_keyless(self, ai_config):
|
||||
"""Evaluates has_architect correctly for keyless models and auth configs."""
|
||||
from connpy.ai import ai
|
||||
# 1. Keyless model (Vertex)
|
||||
ai_config.config["ai"]["architect_api_key"] = None
|
||||
ai_config.config["ai"]["architect_auth"] = {}
|
||||
ai_config.config["ai"]["architect_model"] = "vertex/gemini-pro"
|
||||
myai = ai(ai_config)
|
||||
assert myai.has_architect is True
|
||||
|
||||
# 2. Architect auth dict is set
|
||||
ai_config.config["ai"]["architect_model"] = "custom-model"
|
||||
ai_config.config["ai"]["architect_auth"] = {"vertex_project": "proj-1"}
|
||||
myai = ai(ai_config)
|
||||
assert myai.has_architect is True
|
||||
|
||||
def test_ask_unpacks_auth_dict(self, ai_config, mock_litellm):
|
||||
"""Verifies that ask unpacks engineer_auth when calling completion."""
|
||||
from connpy.ai import ai
|
||||
ai_config.config["ai"]["engineer_api_key"] = None
|
||||
ai_config.config["ai"]["engineer_auth"] = {"vertex_project": "my-project", "vertex_location": "us-east1"}
|
||||
myai = ai(ai_config)
|
||||
myai.ask("test query", stream=False)
|
||||
# Check mock_litellm completion call
|
||||
mock_litellm["completion"].assert_called()
|
||||
kwargs = mock_litellm["completion"].call_args.kwargs
|
||||
assert kwargs.get("vertex_project") == "my-project"
|
||||
assert kwargs.get("vertex_location") == "us-east1"
|
||||
assert "api_key" not in kwargs
|
||||
|
||||
def test_auth_precedence_no_api_key_injection(self, ai_config):
|
||||
"""Verifies that api_key is not injected into the auth dict when auth is already set (non-empty)."""
|
||||
from connpy.ai import ai
|
||||
ai_config.config["ai"]["engineer_api_key"] = "legacy-eng-key"
|
||||
ai_config.config["ai"]["architect_api_key"] = "legacy-arch-key"
|
||||
ai_config.config["ai"]["engineer_auth"] = {"vertex_project": "proj-eng"}
|
||||
ai_config.config["ai"]["architect_auth"] = {"vertex_project": "proj-arch"}
|
||||
myai = ai(ai_config)
|
||||
assert myai.engineer_auth == {"vertex_project": "proj-eng"}
|
||||
assert "api_key" not in myai.engineer_auth
|
||||
assert myai.architect_auth == {"vertex_project": "proj-arch"}
|
||||
assert "api_key" not in myai.architect_auth
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# register_ai_tool tests
|
||||
# =========================================================================
|
||||
@@ -427,12 +498,14 @@ class TestAISessions:
|
||||
|
||||
def test_generate_session_id(self, myai):
|
||||
session_id = myai._generate_session_id("Any query")
|
||||
# Format: YYYYMMDD-HHMMSS
|
||||
assert len(session_id) == 15
|
||||
# Format: YYYYMMDD-HHMMSS-suffix
|
||||
assert len(session_id) == 20
|
||||
assert "-" in session_id
|
||||
parts = session_id.split("-")
|
||||
assert len(parts) == 3
|
||||
assert len(parts[0]) == 8 # YYYYMMDD
|
||||
assert len(parts[1]) == 6 # HHMMSS
|
||||
assert len(parts[2]) == 4 # suffix
|
||||
|
||||
def test_save_and_load_session(self, myai):
|
||||
history = [
|
||||
|
||||
@@ -193,3 +193,28 @@ def test_build_context_blocks_horizontal_scrolling_ansi():
|
||||
assert len(blocks) >= 1
|
||||
start, end, preview = blocks[0]
|
||||
assert "RP/0/RP0/CPU0:xrd# s show interfaces * | inc" in preview
|
||||
|
||||
|
||||
def test_build_context_blocks_cancelled_command():
|
||||
from connpy.services.ai_service import AIService
|
||||
svc = AIService(None)
|
||||
|
||||
node_info = {"prompt": "router#"}
|
||||
# Command 1: cancelled with Ctrl+C. Command 2: executed successfully.
|
||||
raw_bytes = b"router# show plat\x03\r\nrouter# show ver\r\nrouter# "
|
||||
|
||||
# 0: initial boundary
|
||||
# 18: Ctrl+C pressed (ends Command 1, marked CANCELLED)
|
||||
# 36: Enter pressed (ends Command 2)
|
||||
cmd_byte_positions = [(0, None), (18, "CANCELLED"), (36, None)]
|
||||
|
||||
blocks = svc.build_context_blocks(raw_bytes, cmd_byte_positions, node_info)
|
||||
|
||||
# The cancelled command block (0 to 18) should NOT be registered as a VALID_CMD block.
|
||||
# The block for "show ver" should be registered (starting at 36, ending at current_prompt_pos).
|
||||
# Plus, the final block for "CURRENT CONTEXT".
|
||||
valid_blocks = [b for b in blocks if "CURRENT CONTEXT" not in b[2]]
|
||||
assert len(valid_blocks) == 1
|
||||
assert "show ver" in valid_blocks[0][2]
|
||||
assert "show plat" not in valid_blocks[0][2]
|
||||
|
||||
|
||||
@@ -65,4 +65,80 @@ class TestGetCwd:
|
||||
assert len(dirs_in_result) > 0
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tree completions tests
|
||||
# =========================================================================
|
||||
|
||||
class TestTreeCompletions:
|
||||
def test_config_auth_completions(self):
|
||||
from connpy.completion import _build_tree, resolve_completion
|
||||
tree = _build_tree([], [], [], {}, "/tmp")
|
||||
# Test config completions
|
||||
config_completions = resolve_completion(["config", ""], tree)
|
||||
assert "--engineer-auth" in config_completions
|
||||
assert "--architect-auth" in config_completions
|
||||
|
||||
# Resolve when --engineer-auth is chosen in config
|
||||
auth_comp = resolve_completion(["config", "--engineer-auth", ""], tree)
|
||||
assert isinstance(auth_comp, list)
|
||||
|
||||
# Loop back check:
|
||||
# e.g., connpy config --engineer-auth some_val
|
||||
# should loop back and resolve to config options
|
||||
loop_back_comp = resolve_completion(["config", "--engineer-auth", "some_val", ""], tree)
|
||||
assert "--architect-auth" in loop_back_comp
|
||||
assert "--engineer-auth" in loop_back_comp
|
||||
|
||||
def test_ai_auth_completions(self):
|
||||
from connpy.completion import _build_tree, resolve_completion
|
||||
tree = _build_tree([], [], [], {}, "/tmp")
|
||||
# Test ai completions
|
||||
ai_completions = resolve_completion(["ai", ""], tree)
|
||||
assert "--engineer-auth" in ai_completions
|
||||
assert "--architect-auth" in ai_completions
|
||||
|
||||
# Resolve after choosing option
|
||||
auth_comp = resolve_completion(["ai", "--engineer-auth", ""], tree)
|
||||
assert isinstance(auth_comp, list)
|
||||
|
||||
# Loop back check:
|
||||
# e.g., connpy ai --engineer-auth some_val
|
||||
# should loop back and resolve to ai options, excluding --engineer-auth
|
||||
loop_back_comp = resolve_completion(["ai", "--engineer-auth", "some_val", ""], tree)
|
||||
assert "--architect-auth" in loop_back_comp
|
||||
assert "--engineer-auth" not in loop_back_comp
|
||||
|
||||
def test_sixwindmcp_plugin_completions(self):
|
||||
from connpy.completion import resolve_completion, get_cwd
|
||||
import importlib.util
|
||||
|
||||
# Load the testremote/remote_plugins/sixwindmcp.py plugin
|
||||
plugin_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"testremote", "remote_plugins", "sixwindmcp.py"
|
||||
)
|
||||
spec = importlib.util.spec_from_file_location("sixwindmcp", plugin_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
module.get_cwd = get_cwd
|
||||
|
||||
plugin_node = module._connpy_tree()
|
||||
assert "--set-path" in plugin_node
|
||||
assert "--path" in plugin_node
|
||||
assert "start" in plugin_node
|
||||
|
||||
tree = {"sixwindmcp": plugin_node}
|
||||
|
||||
# Test resolution when --set-path is chosen
|
||||
res = resolve_completion(["sixwindmcp", "--set-path", ""], tree)
|
||||
assert isinstance(res, list)
|
||||
|
||||
# Loop back check:
|
||||
# e.g., connpy sixwindmcp --set-path /tmp start
|
||||
# should loop back and resolve to plugin options
|
||||
loop_back_comp = resolve_completion(["sixwindmcp", "--set-path", "/tmp", ""], tree)
|
||||
assert "start" in loop_back_comp
|
||||
assert "stop" in loop_back_comp
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ def test_plugin_disable(mock_disable, app):
|
||||
|
||||
@patch("connpy.services.ai_service.AIService.list_sessions")
|
||||
def test_ai_list(mock_list_sessions, app):
|
||||
mock_list_sessions.return_value = [{"id": "1", "title": "t", "created_at": "now", "model": "m"}]
|
||||
mock_list_sessions.return_value = ([{"id": "1", "title": "t", "created_at": "now", "model": "m"}], 1)
|
||||
app.start(["ai", "--list"])
|
||||
mock_list_sessions.assert_called_once()
|
||||
|
||||
@@ -262,3 +262,55 @@ def test_type_node_reserved_word(app):
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
app._type_node("bulk")
|
||||
assert exc.value.code == 2
|
||||
|
||||
@patch("connpy.services.config_service.ConfigService.update_setting")
|
||||
@patch("connpy.services.config_service.ConfigService.get_settings")
|
||||
def test_config_auth_inline_json(mock_get_settings, mock_update_setting, app):
|
||||
mock_get_settings.return_value = {"ai": {}}
|
||||
app.start(["config", "--engineer-auth", '{"vertex_project": "test-123"}'])
|
||||
mock_update_setting.assert_called_once()
|
||||
args, kwargs = mock_update_setting.call_args
|
||||
assert args[0] == "ai"
|
||||
assert args[1]["engineer_auth"] == {"vertex_project": "test-123"}
|
||||
|
||||
@patch("connpy.services.config_service.ConfigService.update_setting")
|
||||
@patch("connpy.services.config_service.ConfigService.get_settings")
|
||||
def test_config_auth_inline_yaml(mock_get_settings, mock_update_setting, app):
|
||||
mock_get_settings.return_value = {"ai": {}}
|
||||
app.start(["config", "--architect-auth", 'project: test-yaml'])
|
||||
mock_update_setting.assert_called_once()
|
||||
args, kwargs = mock_update_setting.call_args
|
||||
assert args[0] == "ai"
|
||||
assert args[1]["architect_auth"] == {"project": "test-yaml"}
|
||||
|
||||
@patch("connpy.services.config_service.ConfigService.update_setting")
|
||||
@patch("connpy.services.config_service.ConfigService.get_settings")
|
||||
def test_config_clear_auth(mock_get_settings, mock_update_setting, app):
|
||||
mock_get_settings.return_value = {"ai": {"engineer_auth": {"project": "123"}, "engineer_api_key": "some-key"}}
|
||||
|
||||
app.start(["config", "--engineer-auth", "clear"])
|
||||
args, kwargs = mock_update_setting.call_args
|
||||
assert "engineer_auth" not in args[1]
|
||||
|
||||
app.start(["config", "--engineer-api-key", "none"])
|
||||
args, kwargs = mock_update_setting.call_args
|
||||
assert "engineer_api_key" not in args[1]
|
||||
|
||||
@patch("os.path.exists")
|
||||
@patch("builtins.open")
|
||||
@patch("connpy.services.config_service.ConfigService.update_setting")
|
||||
@patch("connpy.services.config_service.ConfigService.get_settings")
|
||||
def test_config_auth_file_path(mock_get_settings, mock_update_setting, mock_open, mock_exists, app):
|
||||
mock_get_settings.return_value = {"ai": {}}
|
||||
mock_exists.side_effect = lambda p: True if p == "/path/to/creds.json" else False
|
||||
mock_file = MagicMock()
|
||||
mock_file.read.return_value = '{"vertex_project": "file-project"}'
|
||||
mock_open.return_value.__enter__.return_value = mock_file
|
||||
|
||||
app.start(["config", "--engineer-auth", "/path/to/creds.json"])
|
||||
mock_update_setting.assert_called_once()
|
||||
args, kwargs = mock_update_setting.call_args
|
||||
assert args[0] == "ai"
|
||||
assert args[1]["engineer_auth"] == {"vertex_project": "file-project"}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Tests for gRPC auth serialization/deserialization (engineer_auth, architect_auth, provider auth).
|
||||
|
||||
These tests verify that:
|
||||
1. to_struct/from_struct round-trips correctly for auth dicts.
|
||||
2. AIStub.ask() correctly serializes engineer_auth and architect_auth into AskRequest.
|
||||
3. AIServicer.ask() correctly deserializes them and passes them to the service.
|
||||
4. AIStub.configure_provider() serializes auth into ProviderRequest.
|
||||
5. AIServicer.configure_provider() deserializes auth and forwards it to the service.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
from connpy.grpc_layer import connpy_pb2
|
||||
from connpy.grpc_layer.utils import to_struct, from_struct
|
||||
|
||||
|
||||
# --- Unit: Struct round-trip ---
|
||||
|
||||
class TestStructRoundTrip:
|
||||
def test_simple_dict(self):
|
||||
d = {"api_key": "secret", "region": "us-east-1"}
|
||||
assert from_struct(to_struct(d)) == d
|
||||
|
||||
def test_nested_dict(self):
|
||||
d = {"vertex_project": "my-project", "vertex_location": "us-central1", "nested": {"key": "val"}}
|
||||
assert from_struct(to_struct(d)) == d
|
||||
|
||||
def test_empty_dict(self):
|
||||
assert from_struct(to_struct({})) == {}
|
||||
|
||||
def test_none_returns_empty(self):
|
||||
assert from_struct(to_struct(None)) == {}
|
||||
|
||||
|
||||
# --- Unit: AskRequest Struct fields ---
|
||||
|
||||
class TestAskRequestStructFields:
|
||||
def test_engineer_auth_round_trip(self):
|
||||
auth = {"vertex_project": "proj", "vertex_location": "us-central1"}
|
||||
req = connpy_pb2.AskRequest(input_text="hi")
|
||||
req.engineer_auth.CopyFrom(to_struct(auth))
|
||||
assert from_struct(req.engineer_auth) == auth
|
||||
|
||||
def test_architect_auth_round_trip(self):
|
||||
auth = {"api_key": "sk-abc", "base_url": "https://custom.api/v1"}
|
||||
req = connpy_pb2.AskRequest(input_text="hi")
|
||||
req.architect_auth.CopyFrom(to_struct(auth))
|
||||
assert from_struct(req.architect_auth) == auth
|
||||
|
||||
def test_has_field_false_when_unset(self):
|
||||
req = connpy_pb2.AskRequest(input_text="hi")
|
||||
assert not req.HasField("engineer_auth")
|
||||
assert not req.HasField("architect_auth")
|
||||
|
||||
def test_has_field_true_when_set(self):
|
||||
req = connpy_pb2.AskRequest(input_text="hi")
|
||||
req.engineer_auth.CopyFrom(to_struct({"k": "v"}))
|
||||
assert req.HasField("engineer_auth")
|
||||
|
||||
|
||||
# --- Unit: ProviderRequest Struct field ---
|
||||
|
||||
class TestProviderRequestStructField:
|
||||
def test_auth_round_trip(self):
|
||||
auth = {"vertex_project": "proj", "vertex_location": "eu-west1"}
|
||||
req = connpy_pb2.ProviderRequest(provider="vertex", model="gemini-pro")
|
||||
req.auth.CopyFrom(to_struct(auth))
|
||||
assert from_struct(req.auth) == auth
|
||||
|
||||
def test_has_field_false_when_unset(self):
|
||||
req = connpy_pb2.ProviderRequest(provider="openai", model="gpt-4o")
|
||||
assert not req.HasField("auth")
|
||||
|
||||
def test_has_field_true_when_set(self):
|
||||
req = connpy_pb2.ProviderRequest(provider="vertex")
|
||||
req.auth.CopyFrom(to_struct({"vertex_project": "p"}))
|
||||
assert req.HasField("auth")
|
||||
|
||||
|
||||
# --- Integration: Server deserializes auth and passes to service ---
|
||||
|
||||
class TestAIServicerAuthDeserialization:
|
||||
@pytest.fixture
|
||||
def servicer(self, populated_config):
|
||||
from connpy.grpc_layer.server import AIServicer
|
||||
return AIServicer(populated_config)
|
||||
|
||||
def test_configure_provider_passes_auth_to_service(self, servicer):
|
||||
auth = {"vertex_project": "my-proj", "vertex_location": "us-central1"}
|
||||
req = connpy_pb2.ProviderRequest(provider="vertex", model="gemini/gemini-pro", api_key="")
|
||||
req.auth.CopyFrom(to_struct(auth))
|
||||
|
||||
with patch.object(servicer.service, "configure_provider") as mock_cp:
|
||||
mock_context = MagicMock()
|
||||
servicer.configure_provider(req, mock_context)
|
||||
mock_cp.assert_called_once_with("vertex", "gemini/gemini-pro", "", auth=auth)
|
||||
|
||||
def test_configure_provider_no_auth(self, servicer):
|
||||
req = connpy_pb2.ProviderRequest(provider="openai", model="gpt-4o", api_key="sk-test")
|
||||
|
||||
with patch.object(servicer.service, "configure_provider") as mock_cp:
|
||||
mock_context = MagicMock()
|
||||
servicer.configure_provider(req, mock_context)
|
||||
mock_cp.assert_called_once_with("openai", "gpt-4o", "sk-test", auth=None)
|
||||
|
||||
|
||||
# --- Integration: Stub serializes auth into request ---
|
||||
|
||||
class TestAIStubAuthSerialization:
|
||||
@pytest.fixture
|
||||
def ai_stub(self):
|
||||
from connpy.grpc_layer.stubs import AIStub
|
||||
mock_channel = MagicMock()
|
||||
stub = AIStub(mock_channel, "localhost:8048")
|
||||
return stub
|
||||
|
||||
def test_configure_provider_with_auth_serializes_struct(self, ai_stub):
|
||||
auth = {"vertex_project": "proj", "vertex_location": "us-central1"}
|
||||
ai_stub.stub.configure_provider = MagicMock()
|
||||
|
||||
ai_stub.configure_provider("vertex", model="gemini/gemini-pro", auth=auth)
|
||||
|
||||
ai_stub.stub.configure_provider.assert_called_once()
|
||||
sent_req = ai_stub.stub.configure_provider.call_args[0][0]
|
||||
assert sent_req.provider == "vertex"
|
||||
assert sent_req.model == "gemini/gemini-pro"
|
||||
assert sent_req.HasField("auth")
|
||||
assert from_struct(sent_req.auth) == auth
|
||||
|
||||
def test_configure_provider_without_auth_no_struct(self, ai_stub):
|
||||
ai_stub.stub.configure_provider = MagicMock()
|
||||
|
||||
ai_stub.configure_provider("openai", model="gpt-4o", api_key="sk-x")
|
||||
|
||||
sent_req = ai_stub.stub.configure_provider.call_args[0][0]
|
||||
assert not sent_req.HasField("auth")
|
||||
Reference in New Issue
Block a user