quickjs-tart

quickjs-based runtime for wallet-core logic
Log | Files | Refs | README | LICENSE

smbserver.py (16381B)


      1 #!/usr/bin/env python3
      2 # -*- coding: utf-8 -*-
      3 #
      4 #  Project                     ___| | | |  _ \| |
      5 #                             / __| | | | |_) | |
      6 #                            | (__| |_| |  _ <| |___
      7 #                             \___|\___/|_| \_\_____|
      8 #
      9 # Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
     10 #
     11 # This software is licensed as described in the file COPYING, which
     12 # you should have received as part of this distribution. The terms
     13 # are also available at https://curl.se/docs/copyright.html.
     14 #
     15 # You may opt to use, copy, modify, merge, publish, distribute and/or sell
     16 # copies of the Software, and permit persons to whom the Software is
     17 # furnished to do so, under the terms of the COPYING file.
     18 #
     19 # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
     20 # KIND, either express or implied.
     21 #
     22 # SPDX-License-Identifier: curl
     23 #
     24 """Server for testing SMB."""
     25 
     26 from __future__ import (absolute_import, division, print_function,
     27                         unicode_literals)
     28 
     29 import argparse
     30 import logging
     31 import os
     32 import signal
     33 import sys
     34 import tempfile
     35 import threading
     36 
     37 # Import our curl test data helper
     38 from util import ClosingFileHandler, TestData
     39 
     40 if sys.version_info.major >= 3:
     41     import configparser
     42 else:
     43     import ConfigParser as configparser
     44 
     45 # impacket needs to be installed in the Python environment
     46 try:
     47     import impacket  # noqa: F401
     48 except ImportError:
     49     sys.stderr.write(
     50         'Warning: Python package impacket is required for smb testing; '
     51         'use pip or your package manager to install it\n')
     52     sys.exit(1)
     53 from impacket import smb as imp_smb
     54 from impacket import smbserver as imp_smbserver
     55 from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
     56                                 STATUS_SUCCESS)
     57 
     58 log = logging.getLogger(__name__)
     59 SERVER_MAGIC = "SERVER_MAGIC"
     60 TESTS_MAGIC = "TESTS_MAGIC"
     61 VERIFIED_REQ = "verifiedserver"
     62 VERIFIED_RSP = "WE ROOLZ: {pid}\n"
     63 
     64 
     65 class ShutdownHandler(threading.Thread):
     66     """
     67     Cleanly shut down the SMB server.
     68 
     69     This can only be done from another thread while the server is in
     70     serve_forever(), so a thread is spawned here that waits for a shutdown
     71     signal before doing its thing. Use in a with statement around the
     72     serve_forever() call.
     73     """
     74 
     75     def __init__(self, server):
     76         super(ShutdownHandler, self).__init__()
     77         self.server = server
     78         self.shutdown_event = threading.Event()
     79 
     80     def __enter__(self):
     81         self.start()
     82         signal.signal(signal.SIGINT, self._sighandler)
     83         signal.signal(signal.SIGTERM, self._sighandler)
     84 
     85     def __exit__(self, *_):
     86         # Call for shutdown just in case it wasn't done already
     87         self.shutdown_event.set()
     88         # Wait for thread, and therefore also the server, to finish
     89         self.join()
     90         # Uninstall our signal handlers
     91         signal.signal(signal.SIGINT, signal.SIG_DFL)
     92         signal.signal(signal.SIGTERM, signal.SIG_DFL)
     93         # Delete any temporary files created by the server during its run
     94         log.info("Deleting %d temporary file(s)", len(self.server.tmpfiles))
     95         for f in self.server.tmpfiles:
     96             os.unlink(f)
     97 
     98     def _sighandler(self, _signum, _frame):
     99         # Wake up the cleanup task
    100         self.shutdown_event.set()
    101 
    102     def run(self):
    103         # Wait for shutdown signal
    104         self.shutdown_event.wait()
    105         # Notify the server to shut down
    106         self.server.shutdown()
    107 
    108 
    109 def smbserver(options):
    110     """Start up a TCP SMB server that serves forever."""
    111     if options.pidfile:
    112         pid = os.getpid()
    113         # see tests/server/util.c function write_pidfile
    114         if os.name == "nt":
    115             pid += 4194304
    116         with open(options.pidfile, "w") as f:
    117             f.write(str(pid))
    118 
    119     # Here we write a mini config for the server
    120     smb_config = configparser.ConfigParser()
    121     smb_config.add_section("global")
    122     smb_config.set("global", "server_name", "SERVICE")
    123     smb_config.set("global", "server_os", "UNIX")
    124     smb_config.set("global", "server_domain", "WORKGROUP")
    125     smb_config.set("global", "log_file", "None")
    126     smb_config.set("global", "credentials_file", "")
    127 
    128     # We need a share which allows us to test that the server is running
    129     smb_config.add_section("SERVER")
    130     smb_config.set("SERVER", "comment", "server function")
    131     smb_config.set("SERVER", "read only", "yes")
    132     smb_config.set("SERVER", "share type", "0")
    133     smb_config.set("SERVER", "path", SERVER_MAGIC)
    134 
    135     # Have a share for tests.  These files will be autogenerated from the
    136     # test input.
    137     smb_config.add_section("TESTS")
    138     smb_config.set("TESTS", "comment", "tests")
    139     smb_config.set("TESTS", "read only", "yes")
    140     smb_config.set("TESTS", "share type", "0")
    141     smb_config.set("TESTS", "path", TESTS_MAGIC)
    142 
    143     if not options.srcdir or not os.path.isdir(options.srcdir):
    144         raise ScriptError("--srcdir is mandatory")
    145 
    146     test_data_dir = os.path.join(options.srcdir, "data")
    147 
    148     smb_server = TestSmbServer((options.host, options.port),
    149                                config_parser=smb_config,
    150                                test_data_directory=test_data_dir)
    151     log.info("[SMB] setting up SMB server on port %s", options.port)
    152     smb_server.processConfigFile()
    153 
    154     # Start a thread that cleanly shuts down the server on a signal
    155     with ShutdownHandler(smb_server):
    156         # This will block until smb_server.shutdown() is called
    157         smb_server.serve_forever()
    158 
    159     return 0
    160 
    161 
    162 class TestSmbServer(imp_smbserver.SMBSERVER):
    163     """
    164     Test server for SMB which subclasses the impacket SMBSERVER and provides
    165     test functionality.
    166     """
    167 
    168     def __init__(self,
    169                  address,
    170                  config_parser=None,
    171                  test_data_directory=None):
    172         imp_smbserver.SMBSERVER.__init__(self,
    173                                          address,
    174                                          config_parser=config_parser)
    175         self.tmpfiles = []
    176 
    177         # Set up a test data object so we can get test data later.
    178         self.ctd = TestData(test_data_directory)
    179 
    180         # Override smbComNtCreateAndX so we can pretend to have files which
    181         # don't exist.
    182         self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
    183                             self.create_and_x)
    184 
    185     def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
    186         """
    187         Our version of smbComNtCreateAndX looks for special test files and
    188         fools the rest of the framework into opening them as if they were
    189         normal files.
    190         """
    191         conn_data = smb_server.getConnectionData(conn_id)
    192 
    193         # Wrap processing in a try block which allows us to throw SmbError
    194         # to control the flow.
    195         try:
    196             ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
    197                 smb_command["Parameters"])
    198 
    199             path = self.get_share_path(conn_data,
    200                                        ncax_parms["RootFid"],
    201                                        recv_packet["Tid"])
    202             log.info("[SMB] Requested share path: %s", path)
    203 
    204             disposition = ncax_parms["Disposition"]
    205             log.debug("[SMB] Requested disposition: %s", disposition)
    206 
    207             # Currently we only support reading files.
    208             if disposition != imp_smb.FILE_OPEN:
    209                 raise SmbError(STATUS_ACCESS_DENIED,
    210                                    "Only support reading files")
    211 
    212             # Check to see if the path we were given is actually a
    213             # magic path which needs generating on the fly.
    214             if path not in [SERVER_MAGIC, TESTS_MAGIC]:
    215                 # Pass the command onto the original handler.
    216                 return imp_smbserver.SMBCommands.smbComNtCreateAndX(conn_id,
    217                                                                     smb_server,
    218                                                                     smb_command,
    219                                                                     recv_packet)
    220 
    221             flags2 = recv_packet["Flags2"]
    222             ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
    223                                                      data=smb_command[
    224                                                          "Data"])
    225             requested_file = imp_smbserver.decodeSMBString(
    226                 flags2,
    227                 ncax_data["FileName"])
    228             log.debug("[SMB] User requested file '%s'", requested_file)
    229 
    230             if path == SERVER_MAGIC:
    231                 fid, full_path = self.get_server_path(requested_file)
    232             else:
    233                 assert path == TESTS_MAGIC
    234                 fid, full_path = self.get_test_path(requested_file)
    235 
    236             self.tmpfiles.append(full_path)
    237 
    238             resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
    239             resp_data = ""
    240 
    241             # Simple way to generate a fid
    242             if len(conn_data["OpenedFiles"]) == 0:
    243                 fakefid = 1
    244             else:
    245                 fakefid = conn_data["OpenedFiles"].keys()[-1] + 1
    246             resp_parms["Fid"] = fakefid
    247             resp_parms["CreateAction"] = disposition
    248 
    249             if os.path.isdir(path):
    250                 resp_parms[
    251                     "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
    252                 resp_parms["IsDirectory"] = 1
    253             else:
    254                 resp_parms["IsDirectory"] = 0
    255                 resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]
    256 
    257             # Get this file's information
    258             resp_info, error_code = imp_smbserver.queryPathInformation(
    259                 os.path.dirname(full_path), os.path.basename(full_path),
    260                 level=imp_smb.SMB_QUERY_FILE_ALL_INFO)
    261 
    262             if error_code != STATUS_SUCCESS:
    263                 raise SmbError(error_code, "Failed to query path info")
    264 
    265             resp_parms["CreateTime"] = resp_info["CreationTime"]
    266             resp_parms["LastAccessTime"] = resp_info[
    267                 "LastAccessTime"]
    268             resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
    269             resp_parms["LastChangeTime"] = resp_info[
    270                 "LastChangeTime"]
    271             resp_parms["FileAttributes"] = resp_info[
    272                 "ExtFileAttributes"]
    273             resp_parms["AllocationSize"] = resp_info[
    274                 "AllocationSize"]
    275             resp_parms["EndOfFile"] = resp_info["EndOfFile"]
    276 
    277             # Let's store the fid for the connection
    278             # smbServer.log("Create file %s, mode:0x%x" % (pathName, mode))
    279             conn_data["OpenedFiles"][fakefid] = {}
    280             conn_data["OpenedFiles"][fakefid]["FileHandle"] = fid
    281             conn_data["OpenedFiles"][fakefid]["FileName"] = path
    282             conn_data["OpenedFiles"][fakefid]["DeleteOnClose"] = False
    283 
    284         except SmbError as s:
    285             log.debug("[SMB] SmbError hit: %s", s)
    286             error_code = s.error_code
    287             resp_parms = ""
    288             resp_data = ""
    289 
    290         resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
    291         resp_cmd["Parameters"] = resp_parms
    292         resp_cmd["Data"] = resp_data
    293         smb_server.setConnectionData(conn_id, conn_data)
    294 
    295         return [resp_cmd], None, error_code
    296 
    297     def get_share_path(self, conn_data, root_fid, tid):
    298         conn_shares = conn_data["ConnectedShares"]
    299 
    300         if tid in conn_shares:
    301             if root_fid > 0:
    302                 # If we have a rootFid, the path is relative to that fid
    303                 path = conn_data["OpenedFiles"][root_fid]["FileName"]
    304                 log.debug("RootFid present %s!" % path)
    305             else:
    306                 if "path" in conn_shares[tid]:
    307                     path = conn_shares[tid]["path"]
    308                 else:
    309                     raise SmbError(STATUS_ACCESS_DENIED,
    310                                        "Connection share had no path")
    311         else:
    312             raise SmbError(imp_smbserver.STATUS_SMB_BAD_TID,
    313                                "TID was invalid")
    314 
    315         return path
    316 
    317     def get_server_path(self, requested_filename):
    318         log.debug("[SMB] Get server path '%s'", requested_filename)
    319 
    320         if requested_filename not in [VERIFIED_REQ]:
    321             raise SmbError(STATUS_NO_SUCH_FILE, "Couldn't find the file")
    322 
    323         fid, filename = tempfile.mkstemp()
    324         log.debug("[SMB] Created %s (%d) for storing '%s'",
    325                   filename, fid, requested_filename)
    326 
    327         contents = ""
    328 
    329         if requested_filename == VERIFIED_REQ:
    330             log.debug("[SMB] Verifying server is alive")
    331             pid = os.getpid()
    332             # see tests/server/util.c function write_pidfile
    333             if os.name == "nt":
    334                 pid += 4194304
    335             contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')
    336 
    337         self.write_to_fid(fid, contents)
    338         return fid, filename
    339 
    340     def write_to_fid(self, fid, contents):
    341         # Write the contents to file descriptor
    342         os.write(fid, contents)
    343         os.fsync(fid)
    344 
    345         # Rewind the file to the beginning so a read gets us the contents
    346         os.lseek(fid, 0, os.SEEK_SET)
    347 
    348     def get_test_path(self, requested_filename):
    349         log.info("[SMB] Get reply data from 'test%s'", requested_filename)
    350 
    351         fid, filename = tempfile.mkstemp()
    352         log.debug("[SMB] Created %s (%d) for storing test '%s'",
    353                   filename, fid, requested_filename)
    354 
    355         try:
    356             contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
    357             self.write_to_fid(fid, contents)
    358             return fid, filename
    359 
    360         except Exception:
    361             log.exception("Failed to make test file")
    362             raise SmbError(STATUS_NO_SUCH_FILE, "Failed to make test file")
    363 
    364 
    365 class SmbError(Exception):
    366     def __init__(self, error_code, error_message):
    367         super(SmbError, self).__init__(error_message)
    368         self.error_code = error_code
    369 
    370 
    371 class ScriptRC(object):
    372     """Enum for script return codes."""
    373 
    374     SUCCESS = 0
    375     FAILURE = 1
    376     EXCEPTION = 2
    377 
    378 
    379 class ScriptError(Exception):
    380     pass
    381 
    382 
    383 def get_options():
    384     parser = argparse.ArgumentParser()
    385 
    386     parser.add_argument("--port", action="store", default=9017,
    387                       type=int, help="port to listen on")
    388     parser.add_argument("--host", action="store", default="127.0.0.1",
    389                       help="host to listen on")
    390     parser.add_argument("--verbose", action="store", type=int, default=0,
    391                         help="verbose output")
    392     parser.add_argument("--pidfile", action="store",
    393                         help="file name for the PID")
    394     parser.add_argument("--logfile", action="store",
    395                         help="file name for the log")
    396     parser.add_argument("--srcdir", action="store", help="test directory")
    397     parser.add_argument("--id", action="store", help="server ID")
    398     parser.add_argument("--ipv4", action="store_true", default=0,
    399                         help="IPv4 flag")
    400 
    401     return parser.parse_args()
    402 
    403 
    404 def setup_logging(options):
    405     """Set up logging from the command line options."""
    406     root_logger = logging.getLogger()
    407     add_stdout = False
    408 
    409     formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")
    410 
    411     # Write out to a logfile
    412     if options.logfile:
    413         handler = ClosingFileHandler(options.logfile)
    414         handler.setFormatter(formatter)
    415         handler.setLevel(logging.DEBUG)
    416         root_logger.addHandler(handler)
    417     else:
    418         # The logfile wasn't specified. Add a stdout logger.
    419         add_stdout = True
    420 
    421     if options.verbose:
    422         # Add a stdout logger as well in verbose mode
    423         root_logger.setLevel(logging.DEBUG)
    424         add_stdout = True
    425     else:
    426         root_logger.setLevel(logging.WARNING)
    427 
    428     if add_stdout:
    429         stdout_handler = logging.StreamHandler(sys.stdout)
    430         stdout_handler.setFormatter(formatter)
    431         stdout_handler.setLevel(logging.DEBUG)
    432         root_logger.addHandler(stdout_handler)
    433 
    434 
    435 if __name__ == '__main__':
    436     # Get the options from the user.
    437     options = get_options()
    438 
    439     # Setup logging using the user options
    440     setup_logging(options)
    441 
    442     # Run main script.
    443     try:
    444         rc = smbserver(options)
    445     except Exception:
    446         log.exception('Error in SMB server')
    447         rc = ScriptRC.EXCEPTION
    448 
    449     if options.pidfile and os.path.isfile(options.pidfile):
    450         os.unlink(options.pidfile)
    451 
    452     log.info("[SMB] Returning %d", rc)
    453     sys.exit(rc)