sanitise --init inputs

Adds function to check most inputs for config file, to ensure they meet requirements.

For example, that tags don't have a "#" and the username doesn't have a "@"
This commit is contained in:
Hugh Rundle 2020-09-12 14:04:53 +10:00
parent 0bd6831f44
commit e8ff05b11d
2 changed files with 199 additions and 3 deletions

View File

@ -2,6 +2,7 @@
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
import json import json
import os import os
import re
import subprocess import subprocess
import sys import sys
import time import time
@ -29,7 +30,14 @@ def compulsory_input(tags, name, example):
else: else:
value = input(tags[0] + name + tags[2]) value = input(tags[0] + name + tags[2])
return value sanitised = sanitise_input(value, name, tags)
if len(value) > 0 and (sanitised == "ok" or sanitised == None):
return value
else:
if len(value) > 0 and sanitised != None:
print(sanitised)
value = ""
def digit_input(tags, name, example): def digit_input(tags, name, example):
@ -53,8 +61,108 @@ def yes_no_input(tags, name):
def optional_input(tags, name, example): def optional_input(tags, name, example):
value = input(tags[0] + name + tags[1] + example + tags[2])
return value incomplete = True
while incomplete:
value = input(tags[0] + name + tags[1] + example + tags[2])
sanitised = sanitise_input(value, name, tags)
if len(value) > 0 and (sanitised == "ok" or sanitised == None):
incomplete = False
return value
elif len(value) > 0 and sanitised != None:
print(sanitised)
else:
return ""
def sanitise_input(value, input_type, tags):
"""
Check that data entered when running --init complies with requirements
"""
if input_type == "Username":
return (
"Do not include '@' in username, please try again"
if value.startswith("@")
else "ok"
)
if input_type == "Base URL":
error = value.startswith("http") or value.find(".") == -1
return (
"Provide full domain without protocol prefix (e.g. "
+ tags[1]
+ "example.social"
+ tags[2]
+ ", not "
+ tags[1]
+ "http://example.social"
+ tags[2]
+ ")"
if error
else "ok"
)
if input_type == "Toots to keep":
l = value.split(",")
def check(s):
d = s.strip()
if not d.isdigit():
return False
allnum = map(check, l)
return (
"Toot IDs must be numeric and separated with commas"
if False in list(allnum)
else "ok"
)
if input_type == "Hashtags to keep":
l = value.split(",")
def check(s):
d = s.strip()
if d.isdigit():
return False
if not re.fullmatch(r"[\w]+", d, flags=re.IGNORECASE):
return False
complies = map(check, l)
return_string = (
"Hashtags must not include '#' and must match rules at "
+ tags[0]
+ "https://docs.joinmastodon.org/user/posting/#hashtags"
+ tags[2]
)
return return_string if False in list(complies) else "ok"
if input_type == "Visibility to keep":
l = value.split(",")
viz_options = set(["public", "unlisted", "private", "direct"])
def check(s):
d = [s.strip().lower()]
intersects = viz_options.intersection(d)
if len(intersects) == 0:
return False
complies = map(check, l)
return_string = "Valid values are one or more of 'public', 'unlisted', 'private' or 'direct'"
return return_string if False in list(complies) else "ok"
if input_type == "Archive path":
path = (
os.path.expanduser(value)
if len(str(value)) > 0 and str(value)[0] == "~"
else value
)
response = (
"ok"
if os.path.exists(path)
else "That directory does not exist, please try again"
)
return response
def init(): def init():

View File

@ -291,6 +291,94 @@ def test_init(monkeypatch, tmpdir):
assert os.path.exists(os.path.join(current_dir, "config.yaml")) assert os.path.exists(os.path.join(current_dir, "config.yaml"))
def test_init_archive_path(tmpdir):
good_path = tmpdir.mkdir("archive_dir") # temporary directory for testing
wrong = ephemetoot.sanitise_input(
os.path.join(good_path, "/bad/path/"), "Archive path", None
)
ok = ephemetoot.sanitise_input(good_path, "Archive path", None)
also_ok = ephemetoot.sanitise_input("~/Desktop", "Archive path", None)
assert ok == "ok"
assert wrong == "That directory does not exist, please try again"
def test_init_sanitise_id_list():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input(
"987654321, toot_id_number", "Toots to keep", tags
)
also_wrong = ephemetoot.sanitise_input("toot_id_number", "Toots to keep", tags)
ok = ephemetoot.sanitise_input("1234598745, 999933335555", "Toots to keep", tags)
also_ok = ephemetoot.sanitise_input("1234598745", "Toots to keep", tags)
assert wrong == "Toot IDs must be numeric and separated with commas"
assert also_wrong == "Toot IDs must be numeric and separated with commas"
assert ok == "ok"
assert also_ok == "ok"
def test_init_sanitise_tag_list():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("#tag, another_tag", "Hashtags to keep", tags)
also_wrong = ephemetoot.sanitise_input("tag, another tag", "Hashtags to keep", tags)
still_wrong = ephemetoot.sanitise_input("tag, 12345", "Hashtags to keep", tags)
ok = ephemetoot.sanitise_input("tag123, another_TAG", "Hashtags to keep", tags)
also_ok = ephemetoot.sanitise_input("single_tag", "Hashtags to keep", tags)
error = (
"Hashtags must not include '#' and must match rules at "
+ tags[0]
+ "https://docs.joinmastodon.org/user/posting/#hashtags"
+ tags[2]
)
assert ok == "ok"
assert also_ok == "ok"
assert wrong == error
assert also_wrong == error
assert still_wrong == error
def test_init_sanitise_url():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("http://example.social", "Base URL", tags)
also_wrong = ephemetoot.sanitise_input("http://example.social", "Base URL", tags)
ok = ephemetoot.sanitise_input("example.social", "Base URL", tags)
assert (
wrong
== "Provide full domain without protocol prefix (e.g. \033[2mexample.social\033[0m, not \033[2mhttp://example.social\033[0m)"
)
assert ok == "ok"
def test_init_sanitise_username():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("@alice", "Username", tags)
ok = ephemetoot.sanitise_input("alice", "Username", tags)
assert wrong == "Do not include '@' in username, please try again"
assert ok == "ok"
def test_init_sanitise_visibility_list():
tags = ("\033[96m", "\033[2m", "\033[0m")
wrong = ephemetoot.sanitise_input("nonexistent", "Visibility to keep", tags)
also_wrong = ephemetoot.sanitise_input("direct public", "Visibility to keep", tags)
ok = ephemetoot.sanitise_input("direct", "Visibility to keep", tags)
also_ok = ephemetoot.sanitise_input("direct, public", "Visibility to keep", tags)
error = (
"Valid values are one or more of 'public', 'unlisted', 'private' or 'direct'"
)
assert ok == "ok"
assert also_ok == "ok"
assert wrong == error
assert also_wrong == error
def test_jsondefault(): def test_jsondefault():
d = ephemetoot.jsondefault(toot.created_at) d = ephemetoot.jsondefault(toot.created_at)
assert d == "2020-05-09T02:17:18.598000+00:00" assert d == "2020-05-09T02:17:18.598000+00:00"